From 9855c89e5b676cf694309b290713b5147347e4bc Mon Sep 17 00:00:00 2001 From: qiuyesuifeng Date: Tue, 13 Oct 2015 15:52:51 +0800 Subject: [PATCH 01/54] expression: add date_add func skeleton. --- expression/date_add.go | 54 ++++++++++++++++++++++++++++++++++++++++++ expression/extract.go | 2 +- expression/visitor.go | 21 +++++++++++++++- 3 files changed, 75 insertions(+), 2 deletions(-) create mode 100644 expression/date_add.go diff --git a/expression/date_add.go b/expression/date_add.go new file mode 100644 index 0000000000..60985325d9 --- /dev/null +++ b/expression/date_add.go @@ -0,0 +1,54 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "github.com/pingcap/tidb/context" +) + +// DateAdd is for time date_add function. +// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add +type DateAdd struct { + Unit string + Date Expression + Interval Expression +} + +// Clone implements the Expression Clone interface. +func (e *DateAdd) Clone() Expression { + n := *e + return &n +} + +// Eval implements the Expression Eval interface. +func (e *DateAdd) Eval(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) { + // TODO + return nil, nil +} + +// IsStatic implements the Expression IsStatic interface. +func (e *DateAdd) IsStatic() bool { + return e.Date.IsStatic() && e.Interval.IsStatic() +} + +// String implements the Expression String interface. +func (e *DateAdd) String() string { + // TODO + return "" +} + +// Accept implements the Visitor Accept interface. +func (e *DateAdd) Accept(v Visitor) (Expression, error) { + return v.VisitDateAdd(e) +} diff --git a/expression/extract.go b/expression/extract.go index 04d6c1f9e1..b36285ebca 100644 --- a/expression/extract.go +++ b/expression/extract.go @@ -24,7 +24,7 @@ import ( ) // Extract is for time extract function. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_extract +// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_extract type Extract struct { Unit string Date Expression diff --git a/expression/visitor.go b/expression/visitor.go index b6b7c04bec..6ec16f765f 100644 --- a/expression/visitor.go +++ b/expression/visitor.go @@ -106,6 +106,9 @@ type Visitor interface { // VisitFunctionTrim visits FunctionTrim expression. VisitFunctionTrim(v *FunctionTrim) (Expression, error) + + // VisitDateAdd visits DateAdd expression. + VisitDateAdd(da *DateAdd) (Expression, error) } // BaseVisitor is the base implementation of Visitor. @@ -456,7 +459,7 @@ func (bv *BaseVisitor) VisitWhenClause(w *WhenClause) (Expression, error) { return w, nil } -// VisitExtract implements Visitor +// VisitExtract implements Visitor interface. func (bv *BaseVisitor) VisitExtract(v *Extract) (Expression, error) { var err error v.Date, err = v.Date.Accept(bv.V) @@ -478,3 +481,19 @@ func (bv *BaseVisitor) VisitFunctionTrim(ss *FunctionTrim) (Expression, error) { } return ss, nil } + +// VisitDateAdd implements Visitor interface. +func (bv *BaseVisitor) VisitDateAdd(da *DateAdd) (Expression, error) { + var err error + da.Date, err = da.Date.Accept(bv.V) + if err != nil { + return da, errors.Trace(err) + } + + da.Interval, err = da.Interval.Accept(bv.V) + if err != nil { + return da, errors.Trace(err) + } + + return da, nil +} From 14c210bbedac55232a3265ce26ce10d8ba1d4652 Mon Sep 17 00:00:00 2001 From: qiuyesuifeng Date: Tue, 13 Oct 2015 16:39:47 +0800 Subject: [PATCH 02/54] parser: add add_date parser support. --- parser/parser.y | 11 ++++++++++- parser/scanner.l | 5 +++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/parser/parser.y b/parser/parser.y index 0f9801ffca..73d7d81e1c 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -112,6 +112,7 @@ import ( currentUser "CURRENT_USER" database "DATABASE" databases "DATABASES" + dateAdd "DATE_ADD" day "DAY" dayofmonth "DAYOFMONTH" dayofweek "DAYOFWEEK" @@ -162,6 +163,7 @@ import ( index "INDEX" inner "INNER" insert "INSERT" + interval "INTERVAL" into "INTO" is "IS" join "JOIN" @@ -696,7 +698,6 @@ ColumnPosition: } } - AlterSpecificationList: AlterSpecification { @@ -2228,6 +2229,14 @@ FunctionCallNonKeyword: return 1 } } +| "DATE_ADD" '(' Expression ',' "INTERVAL" Expression TimeUnit ')' + { + $$ = &expression.DateAdd{ + Unit: $7.(string), + Date: $3.(expression.Expression), + Interval: $6.(expression.Expression), + } + } | "EXTRACT" '(' TimeUnit "FROM" Expression ')' { $$ = &expression.Extract{ diff --git a/parser/scanner.l b/parser/scanner.l index b84c6f4cfb..2a5b783b76 100644 --- a/parser/scanner.l +++ b/parser/scanner.l @@ -284,6 +284,7 @@ current_date {c}{u}{r}{r}{e}{n}{t}_{d}{a}{t}{e} current_user {c}{u}{r}{r}{e}{n}{t}_{u}{s}{e}{r} database {d}{a}{t}{a}{b}{a}{s}{e} databases {d}{a}{t}{a}{b}{a}{s}{e}{s} +date_add {d}{a}{t}{e}_{a}{d}{d} day {d}{a}{y} dayofweek {d}{a}{y}{o}{f}{w}{e}{e}{k} dayofmonth {d}{a}{y}{o}{f}{m}{o}{n}{t}{h} @@ -330,6 +331,7 @@ in {i}{n} index {i}{n}{d}{e}{x} inner {i}{n}{n}{e}{r} insert {i}{n}{s}{e}{r}{t} +interval {i}{n}{t}{e}{r}{v}{a}{l} into {i}{n}{t}{o} is {i}{s} join {j}{o}{i}{n} @@ -627,6 +629,8 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {database} lval.item = string(l.val) return database {databases} return databases +{date_add} lval.item = string(l.val) + return dateAdd {day} lval.item = string(l.val) return day {dayofweek} lval.item = string(l.val) @@ -707,6 +711,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {index} return index {inner} return inner {insert} return insert +{interval} return interval {into} return into {in} return in {is} return is From ff8546a539072878461016114ba519d8726b8fe8 Mon Sep 17 00:00:00 2001 From: qiuyesuifeng Date: Tue, 13 Oct 2015 16:40:10 +0800 Subject: [PATCH 03/54] parser: add parser test. --- parser/parser_test.go | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/parser/parser_test.go b/parser/parser_test.go index 12677906ef..b383f27f6c 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -401,6 +401,28 @@ func (s *testParserSuite) TestBuiltin(c *C) { {`SELECT TRIM(LEADING 'x' FROM 'xxxbarxxx');`, true}, {`SELECT TRIM(BOTH 'x' FROM 'xxxbarxxx');`, true}, {`SELECT TRIM(TRAILING 'xyz' FROM 'barxxyz');`, true}, + + // For date_add + {`select date_add("2011-11-11 10:10:10.123456", interval 10 microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 10 second)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 10 minute)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 10 hour)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 10 day)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 1 week)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 1 month)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 1 quarter)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 1 year)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10.10" second_microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10.10" minute_microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10" minute_second)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10:10.10" hour_microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10:10" hour_second)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10" hour_minute)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10:10.10" day_microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10:10" day_second)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true}, } s.RunTest(c, table) } From fe6f12f9e5d329cfa5680d510dc7c6038ed0b09a Mon Sep 17 00:00:00 2001 From: qiuyesuifeng Date: Tue, 13 Oct 2015 20:35:55 +0800 Subject: [PATCH 04/54] expression: complete date add function. --- expression/date_add.go | 68 ++++++-- expression/extract.go | 69 +------- expression/helper.go | 359 ++++++++++++++++++++++++++++++++++++++ expression/helper_test.go | 1 - 4 files changed, 416 insertions(+), 81 deletions(-) diff --git a/expression/date_add.go b/expression/date_add.go index 60985325d9..4e54e3c777 100644 --- a/expression/date_add.go +++ b/expression/date_add.go @@ -14,7 +14,13 @@ package expression import ( + "fmt" + "strings" + + "github.com/juju/errors" "github.com/pingcap/tidb/context" + mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/util/types" ) // DateAdd is for time date_add function. @@ -26,29 +32,67 @@ type DateAdd struct { } // Clone implements the Expression Clone interface. -func (e *DateAdd) Clone() Expression { - n := *e +func (da *DateAdd) Clone() Expression { + n := *da return &n } // Eval implements the Expression Eval interface. -func (e *DateAdd) Eval(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) { - // TODO - return nil, nil +func (da *DateAdd) Eval(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) { + dv, err := da.Date.Eval(ctx, args) + if dv == nil || err != nil { + return nil, errors.Trace(err) + } + + sv, err := types.ToString(dv) + if err != nil { + return nil, errors.Trace(err) + } + + f := types.NewFieldType(mysql.TypeDatetime) + f.Decimal = mysql.MaxFsp + + dv, err = types.Convert(sv, f) + if dv == nil || err != nil { + return nil, errors.Trace(err) + } + + t, ok := dv.(mysql.Time) + if !ok { + return nil, errors.Errorf("need time type, but got %T", dv) + } + + iv, err := da.Interval.Eval(ctx, args) + if iv == nil || err != nil { + return nil, errors.Trace(err) + } + + format, err := types.ToString(iv) + if err != nil { + return nil, errors.Trace(err) + } + + years, months, days, durations, err := extractTimeValue(da.Unit, format) + if err != nil { + return nil, errors.Trace(err) + } + + t.Time = t.Time.Add(durations) + t.Time = t.Time.AddDate(int(years), int(months), int(days)) + return t, nil } // IsStatic implements the Expression IsStatic interface. -func (e *DateAdd) IsStatic() bool { - return e.Date.IsStatic() && e.Interval.IsStatic() +func (da *DateAdd) IsStatic() bool { + return da.Date.IsStatic() && da.Interval.IsStatic() } // String implements the Expression String interface. -func (e *DateAdd) String() string { - // TODO - return "" +func (da *DateAdd) String() string { + return fmt.Sprintf("DATE_ADD(%s, INTERVAL %s %s)", da.Date, da.Interval, strings.ToUpper(da.Unit)) } // Accept implements the Visitor Accept interface. -func (e *DateAdd) Accept(v Visitor) (Expression, error) { - return v.VisitDateAdd(e) +func (da *DateAdd) Accept(v Visitor) (Expression, error) { + return v.VisitDateAdd(da) } diff --git a/expression/extract.go b/expression/extract.go index b36285ebca..c5c34d8886 100644 --- a/expression/extract.go +++ b/expression/extract.go @@ -56,7 +56,7 @@ func (e *Extract) Eval(ctx context.Context, args map[interface{}]interface{}) (i return nil, errors.Errorf("need time type, but got %T", v) } - n, err1 := extractTime(e.Unit, t) + n, err1 := extractTimeNum(e.Unit, t) if err1 != nil { return nil, errors.Trace(err1) } @@ -78,70 +78,3 @@ func (e *Extract) String() string { func (e *Extract) Accept(v Visitor) (Expression, error) { return v.VisitExtract(e) } - -func extractTime(unit string, t mysql.Time) (int64, error) { - switch strings.ToUpper(unit) { - case "MICROSECOND": - return int64(t.Nanosecond() / 1000), nil - case "SECOND": - return int64(t.Second()), nil - case "MINUTE": - return int64(t.Minute()), nil - case "HOUR": - return int64(t.Hour()), nil - case "DAY": - return int64(t.Day()), nil - case "WEEK": - _, week := t.ISOWeek() - return int64(week), nil - case "MONTH": - return int64(t.Month()), nil - case "QUARTER": - m := int64(t.Month()) - // 1 - 3 -> 1 - // 4 - 6 -> 2 - // 7 - 9 -> 3 - // 10 - 12 -> 4 - return (m + 2) / 3, nil - case "YEAR": - return int64(t.Year()), nil - case "SECOND_MICROSECOND": - return int64(t.Second())*1000000 + int64(t.Nanosecond())/1000, nil - case "MINUTE_MICROSECOND": - _, m, s := t.Clock() - return int64(m)*100000000 + int64(s)*1000000 + int64(t.Nanosecond())/1000, nil - case "MINUTE_SECOND": - _, m, s := t.Clock() - return int64(m*100 + s), nil - case "HOUR_MICROSECOND": - h, m, s := t.Clock() - return int64(h)*10000000000 + int64(m)*100000000 + int64(s)*1000000 + int64(t.Nanosecond())/1000, nil - case "HOUR_SECOND": - h, m, s := t.Clock() - return int64(h)*10000 + int64(m)*100 + int64(s), nil - case "HOUR_MINUTE": - h, m, _ := t.Clock() - return int64(h)*100 + int64(m), nil - case "DAY_MICROSECOND": - h, m, s := t.Clock() - d := t.Day() - return int64(d*1000000+h*10000+m*100+s)*1000000 + int64(t.Nanosecond())/1000, nil - case "DAY_SECOND": - h, m, s := t.Clock() - d := t.Day() - return int64(d)*1000000 + int64(h)*10000 + int64(m)*100 + int64(s), nil - case "DAY_MINUTE": - h, m, _ := t.Clock() - d := t.Day() - return int64(d)*10000 + int64(h)*100 + int64(m), nil - case "DAY_HOUR": - h, _, _ := t.Clock() - d := t.Day() - return int64(d)*100 + int64(h), nil - case "YEAR_MONTH": - y, m, _ := t.Date() - return int64(y)*100 + int64(m), nil - default: - return 0, errors.Errorf("invalid unit %s", unit) - } -} diff --git a/expression/helper.go b/expression/helper.go index be573e3fba..4ff80dcd0c 100644 --- a/expression/helper.go +++ b/expression/helper.go @@ -394,3 +394,362 @@ func hasSameColumnCount(ctx context.Context, e Expression, args ...Expression) e return nil } + +func extractTimeNum(unit string, t mysql.Time) (int64, error) { + switch strings.ToUpper(unit) { + case "MICROSECOND": + return int64(t.Nanosecond() / 1000), nil + case "SECOND": + return int64(t.Second()), nil + case "MINUTE": + return int64(t.Minute()), nil + case "HOUR": + return int64(t.Hour()), nil + case "DAY": + return int64(t.Day()), nil + case "WEEK": + _, week := t.ISOWeek() + return int64(week), nil + case "MONTH": + return int64(t.Month()), nil + case "QUARTER": + m := int64(t.Month()) + // 1 - 3 -> 1 + // 4 - 6 -> 2 + // 7 - 9 -> 3 + // 10 - 12 -> 4 + return (m + 2) / 3, nil + case "YEAR": + return int64(t.Year()), nil + case "SECOND_MICROSECOND": + return int64(t.Second())*1000000 + int64(t.Nanosecond())/1000, nil + case "MINUTE_MICROSECOND": + _, m, s := t.Clock() + return int64(m)*100000000 + int64(s)*1000000 + int64(t.Nanosecond())/1000, nil + case "MINUTE_SECOND": + _, m, s := t.Clock() + return int64(m*100 + s), nil + case "HOUR_MICROSECOND": + h, m, s := t.Clock() + return int64(h)*10000000000 + int64(m)*100000000 + int64(s)*1000000 + int64(t.Nanosecond())/1000, nil + case "HOUR_SECOND": + h, m, s := t.Clock() + return int64(h)*10000 + int64(m)*100 + int64(s), nil + case "HOUR_MINUTE": + h, m, _ := t.Clock() + return int64(h)*100 + int64(m), nil + case "DAY_MICROSECOND": + h, m, s := t.Clock() + d := t.Day() + return int64(d*1000000+h*10000+m*100+s)*1000000 + int64(t.Nanosecond())/1000, nil + case "DAY_SECOND": + h, m, s := t.Clock() + d := t.Day() + return int64(d)*1000000 + int64(h)*10000 + int64(m)*100 + int64(s), nil + case "DAY_MINUTE": + h, m, _ := t.Clock() + d := t.Day() + return int64(d)*10000 + int64(h)*100 + int64(m), nil + case "DAY_HOUR": + h, _, _ := t.Clock() + d := t.Day() + return int64(d)*100 + int64(h), nil + case "YEAR_MONTH": + y, m, _ := t.Date() + return int64(y)*100 + int64(m), nil + default: + return 0, errors.Errorf("invalid unit %s", unit) + } +} + +func extractSingleTimeValue(unit string, format string) (int64, int64, int64, time.Duration, error) { + iv, err := strconv.ParseInt(format, 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + v := time.Duration(iv) + switch strings.ToUpper(unit) { + case "MICROSECOND": + return 0, 0, 0, v * time.Microsecond, nil + case "SECOND": + return 0, 0, 0, v * time.Second, nil + case "MINUTE": + return 0, 0, 0, v * time.Minute, nil + case "HOUR": + return 0, 0, 0, v * time.Hour, nil + case "DAY": + return 0, 0, 1, 0, nil + case "WEEK": + return 0, 0, 7, 0, nil + case "MONTH": + return 0, 1, 0, 0, nil + case "QUARTER": + return 0, 3, 0, 0, nil + case "YEAR": + return 1, 0, 0, 0, nil + } + + return 0, 0, 0, 0, errors.Errorf("invalid singel timeunit - %s", unit) +} + +// Format is `SS.FFFFFF`. +func extractSecondMicrosecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, ".") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + seconds, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + microsecond, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, 0, time.Duration(seconds)*time.Second + time.Duration(microsecond)*time.Microsecond, nil +} + +// Format is `MM:SS.FFFFFF`. +func extractMinuteMicrosecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, ":") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + minutes, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + _, _, _, value, err := extractSecondMicrosecond(fields[1]) + if err != nil { + return 0, 0, 0, 0, errors.Trace(err) + } + + return 0, 0, 0, time.Duration(minutes)*time.Minute + value, nil +} + +// Format is `MM:SS`. +func extractMinuteSecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, ":") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + minutes, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + seconds, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, 0, time.Duration(minutes)*time.Minute + time.Duration(seconds)*time.Second, nil +} + +// Format is `HH:MM:SS.FFFFFF`. +func extractHourMicrosecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, ":") + if len(fields) != 3 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + hours, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + minutes, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + _, _, _, value, err := extractSecondMicrosecond(fields[2]) + if err != nil { + return 0, 0, 0, 0, errors.Trace(err) + } + + return 0, 0, 0, time.Duration(hours)*time.Hour + time.Duration(minutes)*time.Minute + value, nil +} + +// Format is `HH:MM:SS`. +func extractHourSecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, ":") + if len(fields) != 3 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + hours, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + minutes, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + seconds, err := strconv.ParseInt(fields[2], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, 0, time.Duration(hours)*time.Hour + time.Duration(minutes)*time.Minute + time.Duration(seconds)*time.Second, nil +} + +// Format is `HH:MM`. +func extractHourMinute(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, ":") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + hours, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + minutes, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, 0, time.Duration(hours)*time.Hour + time.Duration(minutes)*time.Minute, nil +} + +// Format is `DD HH:MM:SS.FFFFFF`. +func extractDayMicrosecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, " ") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + days, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + _, _, _, value, err := extractHourMicrosecond(fields[1]) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, days, value, nil +} + +// Format is `DD HH:MM:SS`. +func extractDaySecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, " ") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + days, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + _, _, _, value, err := extractHourSecond(fields[1]) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, days, value, nil +} + +// Format is `DD HH:MM`. +func extractDayMinute(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, " ") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + days, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + _, _, _, value, err := extractHourMinute(fields[1]) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, days, value, nil +} + +// Format is `DD HH`. +func extractDayHour(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, " ") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + days, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + hours, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, days, time.Duration(hours) * time.Hour, nil +} + +// Format is `YYYY-MM`. +func extractYearMonth(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, " ") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + days, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + hours, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, days, time.Duration(hours) * time.Hour, nil +} + +func extractTimeValue(unit string, format string) (int64, int64, int64, time.Duration, error) { + switch strings.ToUpper(unit) { + case "MICROSECOND", "SECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "QUARTER", "YEAR": + return extractSingleTimeValue(unit, format) + case "SECOND_MICROSECOND": + return extractSecondMicrosecond(format) + case "MINUTE_MICROSECOND": + return extractMinuteMicrosecond(format) + case "MINUTE_SECOND": + return extractMinuteSecond(format) + case "HOUR_MICROSECOND": + return extractHourMicrosecond(format) + case "HOUR_SECOND": + return extractHourSecond(format) + case "HOUR_MINUTE": + return extractHourMinute(format) + case "DAY_MICROSECOND": + return extractDayMicrosecond(format) + case "DAY_SECOND": + return extractDaySecond(format) + case "DAY_MINUTE": + return extractDayMinute(format) + case "DAY_HOUR": + return extractDayHour(format) + case "YEAR_MONTH": + return extractYearMonth(format) + default: + return 0, 0, 0, 0, errors.Errorf("invalid singel timeunit - %s", unit) + } +} diff --git a/expression/helper_test.go b/expression/helper_test.go index 6bda7bc2ac..1eacac2302 100644 --- a/expression/helper_test.go +++ b/expression/helper_test.go @@ -5,7 +5,6 @@ import ( "time" . "github.com/pingcap/check" - "github.com/pingcap/tidb/model" mysql "github.com/pingcap/tidb/mysqldef" "github.com/pingcap/tidb/parser/opcode" From fec9aa006f689b2c3d9442d60dc640aa66c97e8a Mon Sep 17 00:00:00 2001 From: dongxu Date: Tue, 13 Oct 2015 20:58:01 +0800 Subject: [PATCH 05/54] mvcc: compact implementation backup --- store/localstore/kv.go | 2 + store/localstore/local_gc.go | 76 ++++++++++++++++++++++++++++++++++++ store/localstore/txn.go | 11 ++++-- 3 files changed, 85 insertions(+), 4 deletions(-) create mode 100644 store/localstore/local_gc.go diff --git a/store/localstore/kv.go b/store/localstore/kv.go index 58934543fc..9871f7fda5 100644 --- a/store/localstore/kv.go +++ b/store/localstore/kv.go @@ -38,6 +38,7 @@ type dbStore struct { keysLocked map[string]uint64 uuid string path string + gc *localstoreGC } type storeCache struct { @@ -85,6 +86,7 @@ func (d Driver) Open(schema string) (kv.Storage, error) { uuid: uuid.NewV4().String(), path: schema, db: db, + gc: newLocalGC(), } mc.cache[schema] = s diff --git a/store/localstore/local_gc.go b/store/localstore/local_gc.go new file mode 100644 index 0000000000..eafa18b97f --- /dev/null +++ b/store/localstore/local_gc.go @@ -0,0 +1,76 @@ +package localstore + +import ( + "container/list" + "sync" + "time" + + "github.com/ngaut/log" + "github.com/pingcap/tidb/kv" +) + +var _ kv.GC = (*localstoreGC)(nil) + +type localstoreGC struct { + mu sync.Mutex + recentKeys map[string]struct{} + stopChan chan struct{} + ticker *time.Ticker +} + +func (gc *localstoreGC) OnSet(k kv.Key) { + gc.mu.Lock() + defer gc.mu.Unlock() + gc.recentKeys[string(k)] = struct{}{} +} + +func (gc *localstoreGC) OnGet(k kv.Key) { + gc.mu.Lock() + defer gc.mu.Unlock() + gc.recentKeys[string(k)] = struct{}{} +} + +func (gc *localstoreGC) Do(k kv.Key) { + // TODO + log.Debugf("gc key: %q", k) +} + +func (gc *localstoreGC) Start() { + go func() { + for { + select { + case <-gc.stopChan: + break + case <-gc.ticker.C: + l := list.New() + gc.mu.Lock() + for k, _ := range gc.recentKeys { + // get recent keys list + l.PushBack(k) + } + // clean recentKeys + gc.recentKeys = map[string]struct{}{} + gc.mu.Unlock() + // Do GC + for e := l.Front(); e != nil; e = e.Next() { + gc.Do([]byte(e.Value.(string))) + } + } + } + }() +} + +func (gc *localstoreGC) Stop() { + gc.stopChan <- struct{}{} + gc.ticker.Stop() + close(gc.stopChan) +} + +func newLocalGC() *localstoreGC { + return &localstoreGC{ + recentKeys: map[string]struct{}{}, + stopChan: make(chan struct{}), + // TODO hard code + ticker: time.NewTicker(time.Second * 1), + } +} diff --git a/store/localstore/txn.go b/store/localstore/txn.go index ae315167c8..ddceb5b7be 100644 --- a/store/localstore/txn.go +++ b/store/localstore/txn.go @@ -95,7 +95,6 @@ func (txn *dbTxn) Inc(k kv.Key, step int64) (int64, error) { if err != nil { return 0, errors.Trace(err) } - return intVal, nil } @@ -108,7 +107,6 @@ func (txn *dbTxn) GetInt64(k kv.Key) (int64, error) { if err != nil { return 0, errors.Trace(err) } - intVal, err := strconv.ParseInt(string(val), 10, 0) return intVal, errors.Trace(err) } @@ -130,7 +128,7 @@ func (txn *dbTxn) Get(k kv.Key) ([]byte, error) { if len(val) == 0 { return nil, errors.Trace(kv.ErrNotExist) } - + txn.store.gc.OnGet(k) return val, nil } @@ -143,7 +141,12 @@ func (txn *dbTxn) Set(k kv.Key, data []byte) error { log.Debugf("set key:%q, txn:%d", k, txn.tID) k = kv.EncodeKey(k) - return txn.UnionStore.Set(k, data) + err := txn.UnionStore.Set(k, data) + if err != nil { + return errors.Trace(err) + } + txn.store.gc.OnSet(k) + return nil } func (txn *dbTxn) Seek(k kv.Key, fnKeyCmp func(kv.Key) bool) (kv.Iterator, error) { From d049936db372cb84f5c9327908b6953d7a40cbf1 Mon Sep 17 00:00:00 2001 From: qiuyesuifeng Date: Wed, 14 Oct 2015 15:14:22 +0800 Subject: [PATCH 06/54] *: move extract functions to time.go --- expression/extract.go | 2 +- expression/helper.go | 359 ----------------------------------------- mysqldef/fsp.go | 10 ++ mysqldef/time.go | 361 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 372 insertions(+), 360 deletions(-) diff --git a/expression/extract.go b/expression/extract.go index c5c34d8886..abd0b9af57 100644 --- a/expression/extract.go +++ b/expression/extract.go @@ -56,7 +56,7 @@ func (e *Extract) Eval(ctx context.Context, args map[interface{}]interface{}) (i return nil, errors.Errorf("need time type, but got %T", v) } - n, err1 := extractTimeNum(e.Unit, t) + n, err1 := mysql.ExtractTimeNum(e.Unit, t) if err1 != nil { return nil, errors.Trace(err1) } diff --git a/expression/helper.go b/expression/helper.go index 4ff80dcd0c..be573e3fba 100644 --- a/expression/helper.go +++ b/expression/helper.go @@ -394,362 +394,3 @@ func hasSameColumnCount(ctx context.Context, e Expression, args ...Expression) e return nil } - -func extractTimeNum(unit string, t mysql.Time) (int64, error) { - switch strings.ToUpper(unit) { - case "MICROSECOND": - return int64(t.Nanosecond() / 1000), nil - case "SECOND": - return int64(t.Second()), nil - case "MINUTE": - return int64(t.Minute()), nil - case "HOUR": - return int64(t.Hour()), nil - case "DAY": - return int64(t.Day()), nil - case "WEEK": - _, week := t.ISOWeek() - return int64(week), nil - case "MONTH": - return int64(t.Month()), nil - case "QUARTER": - m := int64(t.Month()) - // 1 - 3 -> 1 - // 4 - 6 -> 2 - // 7 - 9 -> 3 - // 10 - 12 -> 4 - return (m + 2) / 3, nil - case "YEAR": - return int64(t.Year()), nil - case "SECOND_MICROSECOND": - return int64(t.Second())*1000000 + int64(t.Nanosecond())/1000, nil - case "MINUTE_MICROSECOND": - _, m, s := t.Clock() - return int64(m)*100000000 + int64(s)*1000000 + int64(t.Nanosecond())/1000, nil - case "MINUTE_SECOND": - _, m, s := t.Clock() - return int64(m*100 + s), nil - case "HOUR_MICROSECOND": - h, m, s := t.Clock() - return int64(h)*10000000000 + int64(m)*100000000 + int64(s)*1000000 + int64(t.Nanosecond())/1000, nil - case "HOUR_SECOND": - h, m, s := t.Clock() - return int64(h)*10000 + int64(m)*100 + int64(s), nil - case "HOUR_MINUTE": - h, m, _ := t.Clock() - return int64(h)*100 + int64(m), nil - case "DAY_MICROSECOND": - h, m, s := t.Clock() - d := t.Day() - return int64(d*1000000+h*10000+m*100+s)*1000000 + int64(t.Nanosecond())/1000, nil - case "DAY_SECOND": - h, m, s := t.Clock() - d := t.Day() - return int64(d)*1000000 + int64(h)*10000 + int64(m)*100 + int64(s), nil - case "DAY_MINUTE": - h, m, _ := t.Clock() - d := t.Day() - return int64(d)*10000 + int64(h)*100 + int64(m), nil - case "DAY_HOUR": - h, _, _ := t.Clock() - d := t.Day() - return int64(d)*100 + int64(h), nil - case "YEAR_MONTH": - y, m, _ := t.Date() - return int64(y)*100 + int64(m), nil - default: - return 0, errors.Errorf("invalid unit %s", unit) - } -} - -func extractSingleTimeValue(unit string, format string) (int64, int64, int64, time.Duration, error) { - iv, err := strconv.ParseInt(format, 10, 64) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - v := time.Duration(iv) - switch strings.ToUpper(unit) { - case "MICROSECOND": - return 0, 0, 0, v * time.Microsecond, nil - case "SECOND": - return 0, 0, 0, v * time.Second, nil - case "MINUTE": - return 0, 0, 0, v * time.Minute, nil - case "HOUR": - return 0, 0, 0, v * time.Hour, nil - case "DAY": - return 0, 0, 1, 0, nil - case "WEEK": - return 0, 0, 7, 0, nil - case "MONTH": - return 0, 1, 0, 0, nil - case "QUARTER": - return 0, 3, 0, 0, nil - case "YEAR": - return 1, 0, 0, 0, nil - } - - return 0, 0, 0, 0, errors.Errorf("invalid singel timeunit - %s", unit) -} - -// Format is `SS.FFFFFF`. -func extractSecondMicrosecond(format string) (int64, int64, int64, time.Duration, error) { - fields := strings.Split(format, ".") - if len(fields) != 2 { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - seconds, err := strconv.ParseInt(fields[0], 10, 64) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - microsecond, err := strconv.ParseInt(fields[1], 10, 64) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - return 0, 0, 0, time.Duration(seconds)*time.Second + time.Duration(microsecond)*time.Microsecond, nil -} - -// Format is `MM:SS.FFFFFF`. -func extractMinuteMicrosecond(format string) (int64, int64, int64, time.Duration, error) { - fields := strings.Split(format, ":") - if len(fields) != 2 { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - minutes, err := strconv.ParseInt(fields[0], 10, 64) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - _, _, _, value, err := extractSecondMicrosecond(fields[1]) - if err != nil { - return 0, 0, 0, 0, errors.Trace(err) - } - - return 0, 0, 0, time.Duration(minutes)*time.Minute + value, nil -} - -// Format is `MM:SS`. -func extractMinuteSecond(format string) (int64, int64, int64, time.Duration, error) { - fields := strings.Split(format, ":") - if len(fields) != 2 { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - minutes, err := strconv.ParseInt(fields[0], 10, 64) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - seconds, err := strconv.ParseInt(fields[1], 10, 64) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - return 0, 0, 0, time.Duration(minutes)*time.Minute + time.Duration(seconds)*time.Second, nil -} - -// Format is `HH:MM:SS.FFFFFF`. -func extractHourMicrosecond(format string) (int64, int64, int64, time.Duration, error) { - fields := strings.Split(format, ":") - if len(fields) != 3 { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - hours, err := strconv.ParseInt(fields[0], 10, 64) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - minutes, err := strconv.ParseInt(fields[1], 10, 64) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - _, _, _, value, err := extractSecondMicrosecond(fields[2]) - if err != nil { - return 0, 0, 0, 0, errors.Trace(err) - } - - return 0, 0, 0, time.Duration(hours)*time.Hour + time.Duration(minutes)*time.Minute + value, nil -} - -// Format is `HH:MM:SS`. -func extractHourSecond(format string) (int64, int64, int64, time.Duration, error) { - fields := strings.Split(format, ":") - if len(fields) != 3 { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - hours, err := strconv.ParseInt(fields[0], 10, 64) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - minutes, err := strconv.ParseInt(fields[1], 10, 64) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - seconds, err := strconv.ParseInt(fields[2], 10, 64) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - return 0, 0, 0, time.Duration(hours)*time.Hour + time.Duration(minutes)*time.Minute + time.Duration(seconds)*time.Second, nil -} - -// Format is `HH:MM`. -func extractHourMinute(format string) (int64, int64, int64, time.Duration, error) { - fields := strings.Split(format, ":") - if len(fields) != 2 { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - hours, err := strconv.ParseInt(fields[0], 10, 64) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - minutes, err := strconv.ParseInt(fields[1], 10, 64) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - return 0, 0, 0, time.Duration(hours)*time.Hour + time.Duration(minutes)*time.Minute, nil -} - -// Format is `DD HH:MM:SS.FFFFFF`. -func extractDayMicrosecond(format string) (int64, int64, int64, time.Duration, error) { - fields := strings.Split(format, " ") - if len(fields) != 2 { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - days, err := strconv.ParseInt(fields[0], 10, 64) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - _, _, _, value, err := extractHourMicrosecond(fields[1]) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - return 0, 0, days, value, nil -} - -// Format is `DD HH:MM:SS`. -func extractDaySecond(format string) (int64, int64, int64, time.Duration, error) { - fields := strings.Split(format, " ") - if len(fields) != 2 { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - days, err := strconv.ParseInt(fields[0], 10, 64) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - _, _, _, value, err := extractHourSecond(fields[1]) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - return 0, 0, days, value, nil -} - -// Format is `DD HH:MM`. -func extractDayMinute(format string) (int64, int64, int64, time.Duration, error) { - fields := strings.Split(format, " ") - if len(fields) != 2 { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - days, err := strconv.ParseInt(fields[0], 10, 64) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - _, _, _, value, err := extractHourMinute(fields[1]) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - return 0, 0, days, value, nil -} - -// Format is `DD HH`. -func extractDayHour(format string) (int64, int64, int64, time.Duration, error) { - fields := strings.Split(format, " ") - if len(fields) != 2 { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - days, err := strconv.ParseInt(fields[0], 10, 64) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - hours, err := strconv.ParseInt(fields[1], 10, 64) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - return 0, 0, days, time.Duration(hours) * time.Hour, nil -} - -// Format is `YYYY-MM`. -func extractYearMonth(format string) (int64, int64, int64, time.Duration, error) { - fields := strings.Split(format, " ") - if len(fields) != 2 { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - days, err := strconv.ParseInt(fields[0], 10, 64) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - hours, err := strconv.ParseInt(fields[1], 10, 64) - if err != nil { - return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) - } - - return 0, 0, days, time.Duration(hours) * time.Hour, nil -} - -func extractTimeValue(unit string, format string) (int64, int64, int64, time.Duration, error) { - switch strings.ToUpper(unit) { - case "MICROSECOND", "SECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "QUARTER", "YEAR": - return extractSingleTimeValue(unit, format) - case "SECOND_MICROSECOND": - return extractSecondMicrosecond(format) - case "MINUTE_MICROSECOND": - return extractMinuteMicrosecond(format) - case "MINUTE_SECOND": - return extractMinuteSecond(format) - case "HOUR_MICROSECOND": - return extractHourMicrosecond(format) - case "HOUR_SECOND": - return extractHourSecond(format) - case "HOUR_MINUTE": - return extractHourMinute(format) - case "DAY_MICROSECOND": - return extractDayMicrosecond(format) - case "DAY_SECOND": - return extractDaySecond(format) - case "DAY_MINUTE": - return extractDayMinute(format) - case "DAY_HOUR": - return extractDayHour(format) - case "YEAR_MONTH": - return extractYearMonth(format) - default: - return 0, 0, 0, 0, errors.Errorf("invalid singel timeunit - %s", unit) - } -} diff --git a/mysqldef/fsp.go b/mysqldef/fsp.go index b8802e070f..46384d987e 100644 --- a/mysqldef/fsp.go +++ b/mysqldef/fsp.go @@ -80,3 +80,13 @@ func parseFrac(s string, fsp int) (int, error) { // 0.0312 round 2 -> 3 -> 30000 return int(round * math.Pow10(MaxFsp-fsp)), nil } + +// alignFrac is used to generate alignment frac, like `100` -> `100000` +func alignFrac(s string, fsp int) string { + sl := len(s) + if sl < fsp { + return s + strings.Repeat("0", fsp-sl) + } + + return s +} diff --git a/mysqldef/time.go b/mysqldef/time.go index bb01d19926..21c81455d9 100644 --- a/mysqldef/time.go +++ b/mysqldef/time.go @@ -979,3 +979,364 @@ func checkTimestamp(t Time) bool { return true } + +// ExtractTimeNum extracts time value number from time unit and format. +func ExtractTimeNum(unit string, t Time) (int64, error) { + switch strings.ToUpper(unit) { + case "MICROSECOND": + return int64(t.Nanosecond() / 1000), nil + case "SECOND": + return int64(t.Second()), nil + case "MINUTE": + return int64(t.Minute()), nil + case "HOUR": + return int64(t.Hour()), nil + case "DAY": + return int64(t.Day()), nil + case "WEEK": + _, week := t.ISOWeek() + return int64(week), nil + case "MONTH": + return int64(t.Month()), nil + case "QUARTER": + m := int64(t.Month()) + // 1 - 3 -> 1 + // 4 - 6 -> 2 + // 7 - 9 -> 3 + // 10 - 12 -> 4 + return (m + 2) / 3, nil + case "YEAR": + return int64(t.Year()), nil + case "SECOND_MICROSECOND": + return int64(t.Second())*1000000 + int64(t.Nanosecond())/1000, nil + case "MINUTE_MICROSECOND": + _, m, s := t.Clock() + return int64(m)*100000000 + int64(s)*1000000 + int64(t.Nanosecond())/1000, nil + case "MINUTE_SECOND": + _, m, s := t.Clock() + return int64(m*100 + s), nil + case "HOUR_MICROSECOND": + h, m, s := t.Clock() + return int64(h)*10000000000 + int64(m)*100000000 + int64(s)*1000000 + int64(t.Nanosecond())/1000, nil + case "HOUR_SECOND": + h, m, s := t.Clock() + return int64(h)*10000 + int64(m)*100 + int64(s), nil + case "HOUR_MINUTE": + h, m, _ := t.Clock() + return int64(h)*100 + int64(m), nil + case "DAY_MICROSECOND": + h, m, s := t.Clock() + d := t.Day() + return int64(d*1000000+h*10000+m*100+s)*1000000 + int64(t.Nanosecond())/1000, nil + case "DAY_SECOND": + h, m, s := t.Clock() + d := t.Day() + return int64(d)*1000000 + int64(h)*10000 + int64(m)*100 + int64(s), nil + case "DAY_MINUTE": + h, m, _ := t.Clock() + d := t.Day() + return int64(d)*10000 + int64(h)*100 + int64(m), nil + case "DAY_HOUR": + h, _, _ := t.Clock() + d := t.Day() + return int64(d)*100 + int64(h), nil + case "YEAR_MONTH": + y, m, _ := t.Date() + return int64(y)*100 + int64(m), nil + default: + return 0, errors.Errorf("invalid unit %s", unit) + } +} + +func extractSingleTimeValue(unit string, format string) (int64, int64, int64, time.Duration, error) { + iv, err := strconv.ParseInt(format, 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + v := time.Duration(iv) + switch strings.ToUpper(unit) { + case "MICROSECOND": + return 0, 0, 0, v * time.Microsecond, nil + case "SECOND": + return 0, 0, 0, v * time.Second, nil + case "MINUTE": + return 0, 0, 0, v * time.Minute, nil + case "HOUR": + return 0, 0, 0, v * time.Hour, nil + case "DAY": + return 0, 0, iv, 0, nil + case "WEEK": + return 0, 0, 7 * iv, 0, nil + case "MONTH": + return 0, iv, 0, 0, nil + case "QUARTER": + return 0, 3 * iv, 0, 0, nil + case "YEAR": + return iv, 0, 0, 0, nil + } + + return 0, 0, 0, 0, errors.Errorf("invalid singel timeunit - %s", unit) +} + +// Format is `SS.FFFFFF`. +func extractSecondMicrosecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, ".") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + seconds, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + microsecond, err := strconv.ParseInt(alignFrac(fields[1], MaxFsp), 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, 0, time.Duration(seconds)*time.Second + time.Duration(microsecond)*time.Microsecond, nil +} + +// Format is `MM:SS.FFFFFF`. +func extractMinuteMicrosecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, ":") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + minutes, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + _, _, _, value, err := extractSecondMicrosecond(fields[1]) + if err != nil { + return 0, 0, 0, 0, errors.Trace(err) + } + + return 0, 0, 0, time.Duration(minutes)*time.Minute + value, nil +} + +// Format is `MM:SS`. +func extractMinuteSecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, ":") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + minutes, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + seconds, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, 0, time.Duration(minutes)*time.Minute + time.Duration(seconds)*time.Second, nil +} + +// Format is `HH:MM:SS.FFFFFF`. +func extractHourMicrosecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, ":") + if len(fields) != 3 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + hours, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + minutes, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + _, _, _, value, err := extractSecondMicrosecond(fields[2]) + if err != nil { + return 0, 0, 0, 0, errors.Trace(err) + } + + return 0, 0, 0, time.Duration(hours)*time.Hour + time.Duration(minutes)*time.Minute + value, nil +} + +// Format is `HH:MM:SS`. +func extractHourSecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, ":") + if len(fields) != 3 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + hours, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + minutes, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + seconds, err := strconv.ParseInt(fields[2], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, 0, time.Duration(hours)*time.Hour + time.Duration(minutes)*time.Minute + time.Duration(seconds)*time.Second, nil +} + +// Format is `HH:MM`. +func extractHourMinute(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, ":") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + hours, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + minutes, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, 0, time.Duration(hours)*time.Hour + time.Duration(minutes)*time.Minute, nil +} + +// Format is `DD HH:MM:SS.FFFFFF`. +func extractDayMicrosecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, " ") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + days, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + _, _, _, value, err := extractHourMicrosecond(fields[1]) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, days, value, nil +} + +// Format is `DD HH:MM:SS`. +func extractDaySecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, " ") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + days, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + _, _, _, value, err := extractHourSecond(fields[1]) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, days, value, nil +} + +// Format is `DD HH:MM`. +func extractDayMinute(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, " ") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + days, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + _, _, _, value, err := extractHourMinute(fields[1]) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, days, value, nil +} + +// Format is `DD HH`. +func extractDayHour(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, " ") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + days, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + hours, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, days, time.Duration(hours) * time.Hour, nil +} + +// Format is `YYYY-MM`. +func extractYearMonth(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, "-") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + years, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + months, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return years, months, 0, 0, nil +} + +// ExtractTimeValue extracts time value from time unit and format. +func ExtractTimeValue(unit string, format string) (int64, int64, int64, time.Duration, error) { + switch strings.ToUpper(unit) { + case "MICROSECOND", "SECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "QUARTER", "YEAR": + return extractSingleTimeValue(unit, format) + case "SECOND_MICROSECOND": + return extractSecondMicrosecond(format) + case "MINUTE_MICROSECOND": + return extractMinuteMicrosecond(format) + case "MINUTE_SECOND": + return extractMinuteSecond(format) + case "HOUR_MICROSECOND": + return extractHourMicrosecond(format) + case "HOUR_SECOND": + return extractHourSecond(format) + case "HOUR_MINUTE": + return extractHourMinute(format) + case "DAY_MICROSECOND": + return extractDayMicrosecond(format) + case "DAY_SECOND": + return extractDaySecond(format) + case "DAY_MINUTE": + return extractDayMinute(format) + case "DAY_HOUR": + return extractDayHour(format) + case "YEAR_MONTH": + return extractYearMonth(format) + default: + return 0, 0, 0, 0, errors.Errorf("invalid singel timeunit - %s", unit) + } +} From a1a274142b99b2db9463fc8eb21140fbaff9b31c Mon Sep 17 00:00:00 2001 From: qiuyesuifeng Date: Wed, 14 Oct 2015 15:14:52 +0800 Subject: [PATCH 07/54] expression: update date_add time output. --- expression/date_add.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/expression/date_add.go b/expression/date_add.go index 4e54e3c777..b474347f4c 100644 --- a/expression/date_add.go +++ b/expression/date_add.go @@ -72,13 +72,19 @@ func (da *DateAdd) Eval(ctx context.Context, args map[interface{}]interface{}) ( return nil, errors.Trace(err) } - years, months, days, durations, err := extractTimeValue(da.Unit, format) + years, months, days, durations, err := mysql.ExtractTimeValue(da.Unit, strings.TrimSpace(format)) if err != nil { return nil, errors.Trace(err) } t.Time = t.Time.Add(durations) t.Time = t.Time.AddDate(int(years), int(months), int(days)) + + // "2011-11-11 10:10:20.000000" outputs "2011-11-11 10:10:20". + if t.Time.Nanosecond() == 0 { + t.Fsp = 0 + } + return t, nil } From 9f2bb45e3a3a8389cc521a2434d8fcd64f39f9e1 Mon Sep 17 00:00:00 2001 From: qiuyesuifeng Date: Wed, 14 Oct 2015 15:15:56 +0800 Subject: [PATCH 08/54] expression: add date_add function test. --- expression/date_add_test.go | 73 +++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 expression/date_add_test.go diff --git a/expression/date_add_test.go b/expression/date_add_test.go new file mode 100644 index 0000000000..c2acdad775 --- /dev/null +++ b/expression/date_add_test.go @@ -0,0 +1,73 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + . "github.com/pingcap/check" + mysql "github.com/pingcap/tidb/mysqldef" +) + +var _ = Suite(&testDateAddSuite{}) + +type testDateAddSuite struct { +} + +func (t *testDateAddSuite) TestExtract(c *C) { + input := "2011-11-11 10:10:10" + tbl := []struct { + Unit string + Interval interface{} + Expect string + }{ + {"MICROSECOND", "1000", "2011-11-11 10:10:10.001000"}, + {"MICROSECOND", 1000, "2011-11-11 10:10:10.001000"}, + {"SECOND", "10", "2011-11-11 10:10:20"}, + {"MINUTE", "10", "2011-11-11 10:20:10"}, + {"HOUR", "10", "2011-11-11 20:10:10"}, + {"DAY", "11", "2011-11-22 10:10:10"}, + {"WEEK", "2", "2011-11-25 10:10:10"}, + {"MONTH", "2", "2012-01-11 10:10:10"}, + {"QUARTER", "4", "2012-11-11 10:10:10"}, + {"YEAR", "2", "2013-11-11 10:10:10"}, + {"SECOND_MICROSECOND", "10.00100000", "2011-11-11 10:10:20.100000"}, + {"SECOND_MICROSECOND", "10.0010000000", "2011-11-11 10:10:30"}, + {"SECOND_MICROSECOND", "10.0010000010", "2011-11-11 10:10:30.000010"}, + {"MINUTE_MICROSECOND", "10:10.100", "2011-11-11 10:20:20.100000"}, + {"MINUTE_SECOND", "10:10", "2011-11-11 10:20:20"}, + {"HOUR_MICROSECOND", "10:10:10.100", "2011-11-11 20:20:20.100000"}, + {"HOUR_SECOND", "10:10:10", "2011-11-11 20:20:20"}, + {"HOUR_MINUTE", "10:10", "2011-11-11 20:20:10"}, + {"DAY_MICROSECOND", "11 10:10:10.100", "2011-11-22 20:20:20.100000"}, + {"DAY_SECOND", "11 10:10:10", "2011-11-22 20:20:20"}, + {"DAY_MINUTE", "11 10:10", "2011-11-22 20:20:10"}, + {"DAY_HOUR", "11 10", "2011-11-22 20:10:10"}, + {"YEAR_MONTH", "11-1", "2022-12-11 10:10:10"}, + {"YEAR_MONTH", "11-11", "2023-10-11 10:10:10"}, + } + + for _, t := range tbl { + e := &DateAdd{ + Unit: t.Unit, + Date: Value{Val: input}, + Interval: Value{Val: t.Interval}, + } + + v, err := e.Eval(nil, nil) + c.Assert(err, IsNil) + + value, ok := v.(mysql.Time) + c.Assert(ok, IsTrue) + c.Assert(value.String(), Equals, t.Expect) + } +} From 019ea8c77d202a7ac66ab7a8dc30d5aebf83c4c0 Mon Sep 17 00:00:00 2001 From: qiuyesuifeng Date: Wed, 14 Oct 2015 16:18:38 +0800 Subject: [PATCH 09/54] expression: add more test. --- expression/date_add_test.go | 71 +++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/expression/date_add_test.go b/expression/date_add_test.go index c2acdad775..8ad54bcfd9 100644 --- a/expression/date_add_test.go +++ b/expression/date_add_test.go @@ -25,6 +25,40 @@ type testDateAddSuite struct { func (t *testDateAddSuite) TestExtract(c *C) { input := "2011-11-11 10:10:10" + e := &DateAdd{ + Unit: "DAY", + Date: Value{Val: input}, + Interval: Value{Val: "1"}, + } + c.Assert(e.String(), Equals, `DATE_ADD("2011-11-11 10:10:10", INTERVAL "1" DAY)`) + c.Assert(e.Clone(), NotNil) + c.Assert(e.IsStatic(), IsTrue) + + _, err := e.Eval(nil, nil) + c.Assert(err, IsNil) + + // Test null. + e = &DateAdd{ + Unit: "DAY", + Date: Value{Val: nil}, + Interval: Value{Val: "1"}, + } + + v, err := e.Eval(nil, nil) + c.Assert(err, IsNil) + c.Assert(v, IsNil) + + e = &DateAdd{ + Unit: "DAY", + Date: Value{Val: input}, + Interval: Value{Val: nil}, + } + + v, err = e.Eval(nil, nil) + c.Assert(err, IsNil) + c.Assert(v, IsNil) + + // Test eval. tbl := []struct { Unit string Interval interface{} @@ -70,4 +104,41 @@ func (t *testDateAddSuite) TestExtract(c *C) { c.Assert(ok, IsTrue) c.Assert(value.String(), Equals, t.Expect) } + + // Test error. + errInput := "20111111 10:10:10" + errTbl := []struct { + Unit string + Interval interface{} + }{ + {"MICROSECOND", "abc1000"}, + {"MICROSECOND", ""}, + {"SECOND_MICROSECOND", "10"}, + {"MINUTE_MICROSECOND", "10.0000"}, + {"MINUTE_MICROSECOND", "10:10:10.0000"}, + + // MySQL support, but tidb not. + {"HOUR_MICROSECOND", "10:10.0000"}, + {"YEAR_MONTH", "10 1"}, + } + + for _, t := range errTbl { + e := &DateAdd{ + Unit: t.Unit, + Date: Value{Val: input}, + Interval: Value{Val: t.Interval}, + } + + _, err := e.Eval(nil, nil) + c.Assert(err, NotNil) + + e = &DateAdd{ + Unit: t.Unit, + Date: Value{Val: errInput}, + Interval: Value{Val: t.Interval}, + } + + v, err := e.Eval(nil, nil) + c.Assert(err, NotNil, Commentf("%s", v)) + } } From 1d7a14f35538d1bed0a112d299ce31bb5c8f2fb2 Mon Sep 17 00:00:00 2001 From: qiuyesuifeng Date: Thu, 15 Oct 2015 23:14:36 +0800 Subject: [PATCH 10/54] *: address comments. --- expression/date_add_test.go | 2 +- parser/parser.y | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/expression/date_add_test.go b/expression/date_add_test.go index 8ad54bcfd9..e13f3b2605 100644 --- a/expression/date_add_test.go +++ b/expression/date_add_test.go @@ -23,7 +23,7 @@ var _ = Suite(&testDateAddSuite{}) type testDateAddSuite struct { } -func (t *testDateAddSuite) TestExtract(c *C) { +func (t *testDateAddSuite) TestDateAdd(c *C) { input := "2011-11-11 10:10:10" e := &DateAdd{ Unit: "DAY", diff --git a/parser/parser.y b/parser/parser.y index 73d7d81e1c..758e334fcb 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -112,7 +112,7 @@ import ( currentUser "CURRENT_USER" database "DATABASE" databases "DATABASES" - dateAdd "DATE_ADD" + dateAdd "DATE_ADD" day "DAY" dayofmonth "DAYOFMONTH" dayofweek "DAYOFWEEK" From 8faca5e3f91a53155d10c8fd898d1878c1ae5534 Mon Sep 17 00:00:00 2001 From: siddontang Date: Fri, 16 Oct 2015 19:09:28 +0800 Subject: [PATCH 11/54] expression: add refer field for Ident --- expression/binop_test.go | 2 +- expression/helper.go | 4 +++- expression/helper_test.go | 6 +++--- expression/ident.go | 19 +++++++++++++++++++ expression/ident_test.go | 12 +++++++++++- 5 files changed, 37 insertions(+), 6 deletions(-) diff --git a/expression/binop_test.go b/expression/binop_test.go index 81f14ce804..37d4f8ba35 100644 --- a/expression/binop_test.go +++ b/expression/binop_test.go @@ -186,7 +186,7 @@ func (s *testBinOpSuite) TestIdentRelOp(c *C) { f := func(name string) *Ident { return &Ident{ - model.NewCIStr(name), + CIStr: model.NewCIStr(name), } } diff --git a/expression/helper.go b/expression/helper.go index be573e3fba..35a6d479ff 100644 --- a/expression/helper.go +++ b/expression/helper.go @@ -45,13 +45,15 @@ const ( ExprEvalPositionFunc = "$positionFunc" // ExprEvalValuesFunc is the key saving a function to retrieve value for column name. ExprEvalValuesFunc = "$valuesFunc" + // ExprEvalIdentReferFunc is the key saving a function to retrieve value with identifier reference index. + ExprEvalIdentReferFunc = "$identReferFunc" ) var ( // CurrentTimestamp is the keyword getting default value for datetime and timestamp type. CurrentTimestamp = "CURRENT_TIMESTAMP" // CurrentTimeExpr is the expression retireving default value for datetime and timestamp type. - CurrentTimeExpr = &Ident{model.NewCIStr(CurrentTimestamp)} + CurrentTimeExpr = &Ident{CIStr: model.NewCIStr(CurrentTimestamp)} // ZeroTimestamp shows the zero datetime and timestamp. ZeroTimestamp = "0000-00-00 00:00:00" ) diff --git a/expression/helper_test.go b/expression/helper_test.go index 6bda7bc2ac..5e656acc0b 100644 --- a/expression/helper_test.go +++ b/expression/helper_test.go @@ -57,7 +57,7 @@ func (s *testHelperSuite) TestMentionedColumns(c *C) { }{ {Value{1}, 0}, {&BinaryOperation{L: v, R: v}, 0}, - {&Ident{model.NewCIStr("id")}, 1}, + {&Ident{CIStr: model.NewCIStr("id")}, 1}, {&Call{F: "count", Args: []Expression{v}}, 0}, {&IsNull{Expr: v}, 0}, {&PExpr{Expr: v}, 0}, @@ -106,7 +106,7 @@ func (s *testHelperSuite) TestBase(c *C) { {int64(1), int64(1)}, {&UnaryOperation{Op: opcode.Plus, V: Value{1}}, 1}, {&UnaryOperation{Op: opcode.Not, V: Value{1}}, nil}, - {&UnaryOperation{Op: opcode.Plus, V: &Ident{model.NewCIStr("id")}}, nil}, + {&UnaryOperation{Op: opcode.Plus, V: &Ident{CIStr: model.NewCIStr("id")}}, nil}, {nil, nil}, } @@ -228,7 +228,7 @@ func (s *testHelperSuite) TestGetTimeValue(c *C) { {Value{"2012-13-12 00:00:00"}}, {Value{0}}, {Value{int64(1)}}, - {&Ident{model.NewCIStr("xxx")}}, + {&Ident{CIStr: model.NewCIStr("xxx")}}, {NewUnaryOperation(opcode.Minus, Value{int64(1)})}, } diff --git a/expression/ident.go b/expression/ident.go index a2ddff3d23..53881f6223 100644 --- a/expression/ident.go +++ b/expression/ident.go @@ -29,10 +29,23 @@ var ( _ Expression = (*Ident)(nil) ) +const ( + // IdentReferSelectList means the identifier reference is in select list. + IdentReferSelectList = 1 + // IdentReferFromTable means the identifier reference is in FROM table. + IdentReferFromTable = 2 +) + // Ident is the identifier expression. type Ident struct { // model.CIStr contains origin identifier name and its lowercase name. model.CIStr + + // ReferScope means where the identifer reference is, select list or from. + ReferScope int + + // ReferIndex is the index to get the identifer data. + ReferIndex int } // Clone implements the Expression Clone interface. @@ -63,6 +76,12 @@ func (i *Ident) Eval(ctx context.Context, args map[interface{}]interface{}) (v i return nil, nil } + if f, ok := args[ExprEvalIdentReferFunc]; ok { + if got, ok := f.(func(string, int, int) (interface{}, error)); ok && (i.ReferScope == IdentReferSelectList || i.ReferScope == IdentReferFromTable) { + return got(i.L, i.ReferScope, i.ReferIndex) + } + } + if f, ok := args[ExprEvalIdentFunc]; ok { if got, ok := f.(func(string) (interface{}, error)); ok { return got(i.L) diff --git a/expression/ident_test.go b/expression/ident_test.go index f71289be6c..08463d7eee 100644 --- a/expression/ident_test.go +++ b/expression/ident_test.go @@ -26,7 +26,7 @@ type testIdentSuite struct { func (s *testIdentSuite) TestIdent(c *C) { e := Ident{ - model.NewCIStr("id"), + CIStr: model.NewCIStr("id"), } c.Assert(e.IsStatic(), IsFalse) @@ -55,4 +55,14 @@ func (s *testIdentSuite) TestIdent(c *C) { v, err = e.Eval(nil, m) c.Assert(err, IsNil) c.Assert(v, Equals, 1) + + delete(m, ExprEvalIdentFunc) + e.ReferScope = 1 + e.ReferIndex = 1 + m[ExprEvalIdentReferFunc] = func(string, int, int) (interface{}, error) { + return 2, nil + } + v, err = e.Eval(nil, m) + c.Assert(err, IsNil) + c.Assert(v, Equals, 2) } From 80a3c14e61f61e1873d4ea3f2f5738c0fae04274 Mon Sep 17 00:00:00 2001 From: siddontang Date: Fri, 16 Oct 2015 21:43:39 +0800 Subject: [PATCH 12/54] expression: aggregate function can't have agg in arg --- expression/helper.go | 31 +++++++++++++++++++++++++------ expression/helper_test.go | 9 +++++++++ 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/expression/helper.go b/expression/helper.go index 35a6d479ff..a72071c803 100644 --- a/expression/helper.go +++ b/expression/helper.go @@ -149,22 +149,41 @@ func newMentionedAggregateFuncsVisitor() *mentionedAggregateFuncsVisitor { } func (v *mentionedAggregateFuncsVisitor) VisitCall(c *Call) (Expression, error) { + isAggregate, err := IsAggregateFunc(c.F) + if err != nil { + return nil, errors.Trace(err) + } + + if isAggregate { + v.exprs = append(v.exprs, c) + } + + // iaggregate function can't use aggregate function as the arg. + n := len(v.exprs) for _, e := range c.Args { _, err := e.Accept(v) if err != nil { return nil, errors.Trace(err) } } - f, ok := builtin.Funcs[strings.ToLower(c.F)] - if !ok { - return nil, errors.Errorf("unknown function %s", c.F) - } - if f.IsAggregate { - v.exprs = append(v.exprs, c) + + if len(v.exprs) != n { + // here means we have aggregate function in arg. + return nil, errors.Errorf("Invalid use of group function") } + return c, nil } +// IsAggregateFunc checks whether name is an aggregate function or not. +func IsAggregateFunc(name string) (bool, error) { + f, ok := builtin.Funcs[strings.ToLower(name)] + if !ok { + return false, errors.Errorf("unknown function %s", name) + } + return f.IsAggregate, nil +} + // MentionedColumns returns a list of names for Ident expression. func MentionedColumns(e Expression) []string { var names []string diff --git a/expression/helper_test.go b/expression/helper_test.go index 5e656acc0b..5d0141cef9 100644 --- a/expression/helper_test.go +++ b/expression/helper_test.go @@ -47,6 +47,15 @@ func (s *testHelperSuite) TestContainAggFunc(c *C) { b := ContainAggregateFunc(t.Expr) c.Assert(b, Equals, t.Expect) } + + expr := &Call{ + F: "count", + Args: []Expression{ + &Call{F: "count", Args: []Expression{v}}, + }, + } + _, err := MentionedAggregateFuncs(expr) + c.Assert(err, NotNil) } func (s *testHelperSuite) TestMentionedColumns(c *C) { From 20449b5d3fbe812ff65c81ff018ad960ea5e01f0 Mon Sep 17 00:00:00 2001 From: siddontang Date: Sat, 17 Oct 2015 00:08:44 +0800 Subject: [PATCH 13/54] plans: return indices after checking ambiguous. --- plan/plans/select_list.go | 20 ++++++++++++-------- plan/plans/select_list_test.go | 12 ++++++------ 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/plan/plans/select_list.go b/plan/plans/select_list.go index 4560773227..26533b8618 100644 --- a/plan/plans/select_list.go +++ b/plan/plans/select_list.go @@ -150,24 +150,28 @@ func (s *SelectList) CloneHiddenField(name string, tableFields []*field.ResultFi // but "select c1 as a, c1 as a from t group by a" is not. // For MySQL "select c1 as a, c2 + 1 as a from t group by a" is not ambiguous too, // so we will only check identifier too. -// If no ambiguous, -1 means expr refers none in select list, else an index in select list returns. -func (s *SelectList) CheckReferAmbiguous(expr expression.Expression) (int, error) { +// If no ambiguous, nil means expr refers none in select list, else an indices for fields have the same name in select list returns. +func (s *SelectList) CheckReferAmbiguous(expr expression.Expression) ([]int, error) { if _, ok := expr.(*expression.Ident); !ok { - return -1, nil + return nil, nil } name := expr.String() if strings.Contains(name, ".") { // name is qualified, no need to check - return -1, nil + return nil, nil } lastIndex := -1 + var idx []int // only check origin select list, no hidden field. for i := 0; i < s.HiddenFieldOffset; i++ { if s.ResultFields[i].Name != name { continue - } else if _, ok := s.Fields[i].Expr.(*expression.Ident); !ok { + } + + idx = append(idx, i) + if _, ok := s.Fields[i].Expr.(*expression.Ident); !ok { // not identfier, no check continue } @@ -180,18 +184,18 @@ func (s *SelectList) CheckReferAmbiguous(expr expression.Expression) (int, error // check origin name, e,g. "select c1 as c2, c2 from t group by c2" is ambiguous. if s.ResultFields[i].ColumnInfo.Name.O != s.ResultFields[lastIndex].ColumnInfo.Name.O { - return -1, errors.Errorf("refer %s is ambiguous", expr) + return nil, errors.Errorf("refer %s is ambiguous", expr) } // check table name, e.g, "select t.c1, c1 from t group by c1" is not ambiguous. if s.ResultFields[i].TableName != s.ResultFields[lastIndex].TableName { - return -1, errors.Errorf("refer %s is ambiguous", expr) + return nil, errors.Errorf("refer %s is ambiguous", expr) } // TODO: check database name if possible. } - return lastIndex, nil + return idx, nil } // ResolveSelectList gets fields and result fields from selectFields and srcFields, diff --git a/plan/plans/select_list_test.go b/plan/plans/select_list_test.go index 383a14e69b..ffef1a0981 100644 --- a/plan/plans/select_list_test.go +++ b/plan/plans/select_list_test.go @@ -82,11 +82,11 @@ func (s *testSelectListSuite) TestAmbiguous(c *C) { Fields []pair Name string Err bool - Index int + Index []int }{ - {[]pair{{"id", ""}}, "id", false, 0}, - {[]pair{{"id", "a"}, {"name", "a"}}, "a", true, -1}, - {[]pair{{"id", "a"}, {"name", "a"}}, "id", false, -1}, + {[]pair{{"id", ""}}, "id", false, []int{0}}, + {[]pair{{"id", "a"}, {"name", "a"}}, "a", true, nil}, + {[]pair{{"id", "a"}, {"name", "a"}}, "id", false, nil}, } for _, t := range tbl { @@ -96,7 +96,7 @@ func (s *testSelectListSuite) TestAmbiguous(c *C) { sl, err := plans.ResolveSelectList(fs, rs) c.Assert(err, IsNil) - index, err := sl.CheckReferAmbiguous(&expression.Ident{ + idx, err := sl.CheckReferAmbiguous(&expression.Ident{ CIStr: model.NewCIStr(t.Name), }) @@ -106,6 +106,6 @@ func (s *testSelectListSuite) TestAmbiguous(c *C) { } c.Assert(err, IsNil) - c.Assert(t.Index, Equals, index) + c.Assert(t.Index, DeepEquals, idx) } } From 859f2d5ffcc510f40cdce6d1e6b0a8e13c7ed7e0 Mon Sep 17 00:00:00 2001 From: siddontang Date: Sat, 17 Oct 2015 00:08:57 +0800 Subject: [PATCH 14/54] expression: update Ident evaluation. --- expression/ident.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/expression/ident.go b/expression/ident.go index 53881f6223..427355fe79 100644 --- a/expression/ident.go +++ b/expression/ident.go @@ -76,8 +76,10 @@ func (i *Ident) Eval(ctx context.Context, args map[interface{}]interface{}) (v i return nil, nil } + // TODO: we will unify ExprEvalIdentReferFunc and ExprEvalIdentFunc later, + // now just put them here for refactor step by step. if f, ok := args[ExprEvalIdentReferFunc]; ok { - if got, ok := f.(func(string, int, int) (interface{}, error)); ok && (i.ReferScope == IdentReferSelectList || i.ReferScope == IdentReferFromTable) { + if got, ok := f.(func(string, int, int) (interface{}, error)); ok { return got(i.L, i.ReferScope, i.ReferIndex) } } From 0d98d1b3e04d84a3a5b7afaa17c7a45f0e07d4bb Mon Sep 17 00:00:00 2001 From: siddontang Date: Sat, 17 Oct 2015 00:09:54 +0800 Subject: [PATCH 15/54] rsets: use visitor to fix Ident reference and check. --- rset/rsets/groupby.go | 147 +++++++++++++++++++++++++++++++++--------- 1 file changed, 117 insertions(+), 30 deletions(-) diff --git a/rset/rsets/groupby.go b/rset/rsets/groupby.go index aa3560bf97..c33de0cb8e 100644 --- a/rset/rsets/groupby.go +++ b/rset/rsets/groupby.go @@ -39,11 +39,124 @@ type GroupByRset struct { SelectList *plans.SelectList } +type groupByVisitor struct { + expression.BaseVisitor + selectList *plans.SelectList + rootIdent *expression.Ident +} + +func castIdent(e expression.Expression) *expression.Ident { + i, ok := e.(*expression.Ident) + if !ok { + return nil + } + return i +} + +func (v *groupByVisitor) checkIdent(i *expression.Ident) (int, error) { + idx, err := v.selectList.CheckReferAmbiguous(i) + if err != nil { + return -1, errors.Errorf("Column '%s' in group statement is ambiguous", i) + } else if len(idx) == 0 { + return -1, nil + } + + for _, index := range idx { + if _, ok := v.selectList.AggFields[index]; ok { + return -1, errors.Errorf("Reference '%s' not supported (reference to group function)", i) + } + } + + // this identifier may reference multi fields. + // e.g, select c1 as a, c2 + 1 as a from t group by a, + // we will use the first one which is not an identifer. + // so, for select c1 as a, c2 + 1 as a from t group by a, we will use c2 + 1. + for _, index := range idx { + if castIdent(v.selectList.Fields[index].Expr) == nil { + return index, nil + } + } + + return idx[0], nil +} + +func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) { + // Group by ambiguous rule: + // select c1 as a, c2 as a from t group by a is ambiguous + // select c1 as a, c2 as a from t group by a + 1 is ambiguous + // select c1 as c2, c2 from t group by c2 is ambiguous + // select c1 as c2, c2 from t group by c2 + 1 is ambiguous + + var ( + index int + err error + ) + + if v.rootIdent == i { + // The group by is an identifier, we must check it first. + index, err = v.checkIdent(i) + if err != nil { + return nil, errors.Trace(err) + } + } + + // first find this identifier in FROM. + idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag) + if len(idx) > 0 { + i.ReferScope = expression.IdentReferFromTable + i.ReferIndex = idx[0] + return i, nil + } + + if v.rootIdent != i { + // This identifier is the part of the group by, check ambiguous here. + index, err = v.checkIdent(i) + if err != nil { + return nil, errors.Trace(err) + } + } + + // try to find in select list, we have got index using checkIdent before. + if index >= 0 { + // find in select list + i.ReferScope = expression.IdentReferSelectList + i.ReferIndex = index + return i, nil + } + + // TODO: check in out query + // TODO: return unknown field error, but now just return directly. + // Because this may reference outer query. + return i, nil +} + +func (v *groupByVisitor) VisitCall(c *expression.Call) (expression.Expression, error) { + ok, err := expression.IsAggregateFunc(c.F) + if err != nil { + return nil, errors.Trace(err) + } else if ok { + return nil, errors.Errorf("group by cannot contain aggregate function %s", c) + } + + for i, e := range c.Args { + c.Args[i], err = e.Accept(v) + if err != nil { + return nil, errors.Trace(err) + } + } + + return c, nil +} + // Plan gets GroupByDefaultPlan. func (r *GroupByRset) Plan(ctx context.Context) (plan.Plan, error) { fields := r.SelectList.Fields r.SelectList.AggFields = GetAggFields(fields) + visitor := &groupByVisitor{} + visitor.BaseVisitor.V = visitor + visitor.selectList = r.SelectList + aggFields := r.SelectList.AggFields for i, e := range r.By { @@ -70,38 +183,12 @@ func (r *GroupByRset) Plan(ctx context.Context) (plan.Plan, error) { // use Position expression for the associated field. r.By[i] = &expression.Position{N: position} } else { - index, err := r.SelectList.CheckReferAmbiguous(e) + visitor.rootIdent = castIdent(e) + by, err := e.Accept(visitor) if err != nil { - return nil, errors.Errorf("Column '%s' in group statement is ambiguous", e) - } else if _, ok := aggFields[index]; ok { - return nil, errors.Errorf("Can't group on '%s'", e) - } - - // TODO: check more ambiguous case - // Group by ambiguous rule: - // select c1 as a, c2 as a from t group by a is ambiguous - // select c1 as a, c2 as a from t group by a + 1 is ambiguous - // select c1 as c2, c2 from t group by c2 is ambiguous - // select c1 as c2, c2 from t group by c2 + 1 is ambiguous - - // TODO: use visitor to check aggregate function - names := expression.MentionedColumns(e) - for _, name := range names { - indices := field.GetFieldIndex(name, fields[0:r.SelectList.HiddenFieldOffset], field.DefaultFieldFlag) - if len(indices) == 1 { - // check reference to aggregate function, like `select c1, count(c1) as b from t group by b + 1`. - index := indices[0] - if _, ok := aggFields[index]; ok { - return nil, errors.Errorf("Reference '%s' not supported (reference to group function)", name) - } - } - } - - // group by should be an expression, a qualified field name or a select field position, - // but can not contain any aggregate function. - if e := r.By[i]; expression.ContainAggregateFunc(e) { - return nil, errors.Errorf("group by cannot contain aggregate function %s", e.String()) + return nil, errors.Trace(err) } + r.By[i] = by } } From 347b82a9aeaf271f7a902e8cfbcc48d0b054c471 Mon Sep 17 00:00:00 2001 From: siddontang Date: Sat, 17 Oct 2015 00:10:29 +0800 Subject: [PATCH 16/54] plans: use ExprEvalIdentReferFunc to get group key data. --- plan/plans/groupby.go | 14 +++++--------- plan/plans/groupby_test.go | 4 +++- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/plan/plans/groupby.go b/plan/plans/groupby.go index ce81c002e6..a70ff3e692 100644 --- a/plan/plans/groupby.go +++ b/plan/plans/groupby.go @@ -89,15 +89,11 @@ type groupRow struct { func (r *GroupByDefaultPlan) evalGroupKey(ctx context.Context, k []interface{}, outRow []interface{}, in []interface{}) error { // group by items can not contain aggregate field, so we can eval them safely. m := map[interface{}]interface{}{} - m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { - v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag) - if err == nil { - return v, nil - } - - v, err = r.getFieldValueByName(name, outRow) - if err == nil { - return v, nil + m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + if scope == expression.IdentReferFromTable { + return in[index], nil + } else if scope == expression.IdentReferSelectList { + return outRow[index], nil } // try to find in outer query diff --git a/plan/plans/groupby_test.go b/plan/plans/groupby_test.go index 9fde4aea41..4e28b226c3 100644 --- a/plan/plans/groupby_test.go +++ b/plan/plans/groupby_test.go @@ -69,7 +69,9 @@ func (t *testGroupBySuite) TestGroupBy(c *C) { Src: tblPlan, By: []expression.Expression{ &expression.Ident{ - CIStr: model.NewCIStr("id"), + CIStr: model.NewCIStr("id"), + ReferScope: expression.IdentReferFromTable, + ReferIndex: 0, }, }, } From 6344a412ca7939746c1a263a8c3f4bcc9a9e4a7c Mon Sep 17 00:00:00 2001 From: siddontang Date: Sat, 17 Oct 2015 09:32:44 +0800 Subject: [PATCH 17/54] rsets: add castPosition helper function and update --- rset/rsets/groupby.go | 55 +++++++++++++------------------------------ rset/rsets/helper.go | 46 ++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 39 deletions(-) diff --git a/rset/rsets/groupby.go b/rset/rsets/groupby.go index c33de0cb8e..19acf688e1 100644 --- a/rset/rsets/groupby.go +++ b/rset/rsets/groupby.go @@ -45,14 +45,6 @@ type groupByVisitor struct { rootIdent *expression.Ident } -func castIdent(e expression.Expression) *expression.Ident { - i, ok := e.(*expression.Ident) - if !ok { - return nil - } - return i -} - func (v *groupByVisitor) checkIdent(i *expression.Ident) (int, error) { idx, err := v.selectList.CheckReferAmbiguous(i) if err != nil { @@ -157,39 +149,24 @@ func (r *GroupByRset) Plan(ctx context.Context) (plan.Plan, error) { visitor.BaseVisitor.V = visitor visitor.selectList = r.SelectList - aggFields := r.SelectList.AggFields - for i, e := range r.By { - if v, ok := e.(expression.Value); ok { - var position int - switch u := v.Val.(type) { - case int64: - position = int(u) - case uint64: - position = int(u) - default: - continue - } - - if position < 1 || position > len(fields) { - return nil, errors.Errorf("Unknown column '%d' in 'group statement'", position) - } - - index := position - 1 - if _, ok := aggFields[index]; ok { - return nil, errors.Errorf("Can't group on '%s'", fields[index]) - } - - // use Position expression for the associated field. - r.By[i] = &expression.Position{N: position} - } else { - visitor.rootIdent = castIdent(e) - by, err := e.Accept(visitor) - if err != nil { - return nil, errors.Trace(err) - } - r.By[i] = by + pos, err := castPosition(e, r.SelectList, true) + if err != nil { + return nil, errors.Trace(err) } + + if pos != nil { + // use Position expression for the associated field. + r.By[i] = pos + continue + } + + visitor.rootIdent = castIdent(e) + by, err := e.Accept(visitor) + if err != nil { + return nil, errors.Trace(err) + } + r.By[i] = by } return &plans.GroupByDefaultPlan{By: r.By, Src: r.Src, diff --git a/rset/rsets/helper.go b/rset/rsets/helper.go index 53ce63b743..ac1abc638d 100644 --- a/rset/rsets/helper.go +++ b/rset/rsets/helper.go @@ -14,8 +14,10 @@ package rsets import ( + "github.com/juju/errors" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" + "github.com/pingcap/tidb/plan/plans" ) // GetAggFields gets aggregate fields position map. @@ -34,3 +36,47 @@ func HasAggFields(fields []*field.Field) bool { aggFields := GetAggFields(fields) return len(aggFields) > 0 } + +// castIdent returns an Ident expression if e is or nil. +func castIdent(e expression.Expression) *expression.Ident { + i, ok := e.(*expression.Ident) + if !ok { + return nil + } + return i +} + +// castPosition returns an group/order by Position expression if e is a number. +func castPosition(e expression.Expression, selectList *plans.SelectList, isGroupBy bool) (*expression.Position, error) { + v, ok := e.(expression.Value) + if !ok { + return nil, nil + } + + var position int + switch u := v.Val.(type) { + case int64: + position = int(u) + case uint64: + position = int(u) + default: + return nil, nil + } + + if position < 1 || position > selectList.HiddenFieldOffset { + if isGroupBy { + return nil, errors.Errorf("Unknown column '%d' in 'group statement'", position) + } + return nil, errors.Errorf("Unknown column '%d' in 'order clause'", position) + } + + if isGroupBy { + index := position - 1 + if _, ok := selectList.AggFields[index]; ok { + return nil, errors.Errorf("Can't group on '%s'", selectList.Fields[index]) + } + } + + // use Position expression for the associated field. + return &expression.Position{N: position}, nil +} From 7ba661531c1daa79f2d0f81de298eda695f53260 Mon Sep 17 00:00:00 2001 From: siddontang Date: Sat, 17 Oct 2015 10:11:07 +0800 Subject: [PATCH 18/54] rsets: update helper function. --- rset/rsets/groupby.go | 33 ++-------------------- rset/rsets/helper.go | 64 +++++++++++++++++++++++++++++++++++++++---- 2 files changed, 61 insertions(+), 36 deletions(-) diff --git a/rset/rsets/groupby.go b/rset/rsets/groupby.go index 19acf688e1..008a32fb71 100644 --- a/rset/rsets/groupby.go +++ b/rset/rsets/groupby.go @@ -45,33 +45,6 @@ type groupByVisitor struct { rootIdent *expression.Ident } -func (v *groupByVisitor) checkIdent(i *expression.Ident) (int, error) { - idx, err := v.selectList.CheckReferAmbiguous(i) - if err != nil { - return -1, errors.Errorf("Column '%s' in group statement is ambiguous", i) - } else if len(idx) == 0 { - return -1, nil - } - - for _, index := range idx { - if _, ok := v.selectList.AggFields[index]; ok { - return -1, errors.Errorf("Reference '%s' not supported (reference to group function)", i) - } - } - - // this identifier may reference multi fields. - // e.g, select c1 as a, c2 + 1 as a from t group by a, - // we will use the first one which is not an identifer. - // so, for select c1 as a, c2 + 1 as a from t group by a, we will use c2 + 1. - for _, index := range idx { - if castIdent(v.selectList.Fields[index].Expr) == nil { - return index, nil - } - } - - return idx[0], nil -} - func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) { // Group by ambiguous rule: // select c1 as a, c2 as a from t group by a is ambiguous @@ -86,7 +59,7 @@ func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, if v.rootIdent == i { // The group by is an identifier, we must check it first. - index, err = v.checkIdent(i) + index, err = checkIdent(i, v.selectList, groupByClause) if err != nil { return nil, errors.Trace(err) } @@ -102,7 +75,7 @@ func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, if v.rootIdent != i { // This identifier is the part of the group by, check ambiguous here. - index, err = v.checkIdent(i) + index, err = checkIdent(i, v.selectList, groupByClause) if err != nil { return nil, errors.Trace(err) } @@ -150,7 +123,7 @@ func (r *GroupByRset) Plan(ctx context.Context) (plan.Plan, error) { visitor.selectList = r.SelectList for i, e := range r.By { - pos, err := castPosition(e, r.SelectList, true) + pos, err := castPosition(e, r.SelectList, groupByClause) if err != nil { return nil, errors.Trace(err) } diff --git a/rset/rsets/helper.go b/rset/rsets/helper.go index ac1abc638d..0a73c15247 100644 --- a/rset/rsets/helper.go +++ b/rset/rsets/helper.go @@ -46,8 +46,29 @@ func castIdent(e expression.Expression) *expression.Ident { return i } +type clauseType int + +const ( + noneClause clauseType = iota + groupByClause + orderByClause + havingClause +) + +func (clause clauseType) String() string { + switch clause { + case groupByClause: + return "group statement" + case orderByClause: + return "order clause" + case havingClause: + return "having clause" + } + return "none" +} + // castPosition returns an group/order by Position expression if e is a number. -func castPosition(e expression.Expression, selectList *plans.SelectList, isGroupBy bool) (*expression.Position, error) { +func castPosition(e expression.Expression, selectList *plans.SelectList, clause clauseType) (*expression.Position, error) { v, ok := e.(expression.Value) if !ok { return nil, nil @@ -64,13 +85,10 @@ func castPosition(e expression.Expression, selectList *plans.SelectList, isGroup } if position < 1 || position > selectList.HiddenFieldOffset { - if isGroupBy { - return nil, errors.Errorf("Unknown column '%d' in 'group statement'", position) - } - return nil, errors.Errorf("Unknown column '%d' in 'order clause'", position) + return nil, errors.Errorf("Unknown column '%d' in '%s'", position, clause) } - if isGroupBy { + if clause == groupByClause { index := position - 1 if _, ok := selectList.AggFields[index]; ok { return nil, errors.Errorf("Can't group on '%s'", selectList.Fields[index]) @@ -80,3 +98,37 @@ func castPosition(e expression.Expression, selectList *plans.SelectList, isGroup // use Position expression for the associated field. return &expression.Position{N: position}, nil } + +func checkIdent(i *expression.Ident, selectList *plans.SelectList, clause clauseType) (int, error) { + idx, err := selectList.CheckReferAmbiguous(i) + if err != nil { + return -1, errors.Errorf("Column '%s' in %s is ambiguous", i, clause) + } else if len(idx) == 0 { + return -1, nil + } + + // this identifier may reference multi fields. + // e.g, select c1 as a, c2 + 1 as a from t group by a, + // we will use the first one which is not an identifer. + // so, for select c1 as a, c2 + 1 as a from t group by a, we will use c2 + 1. + + useIndex := 0 + found := false + for _, index := range idx { + if clause == groupByClause { + // group by can not reference aggregate fields + if _, ok := selectList.AggFields[index]; ok { + return -1, errors.Errorf("Reference '%s' not supported (reference to group function)", i) + } + } + + if !found { + if castIdent(selectList.Fields[index].Expr) == nil { + useIndex = index + found = true + } + } + } + + return idx[useIndex], nil +} From 16c11ea4d03e4c1f64a09bac072a66766d7383ec Mon Sep 17 00:00:00 2001 From: siddontang Date: Sat, 17 Oct 2015 10:47:37 +0800 Subject: [PATCH 19/54] rsets: fix select fields ident reference. --- plan/plans/fields.go | 7 +++---- rset/rsets/fields.go | 20 ++++++++++++++++++++ rset/rsets/fields_test.go | 3 +++ rset/rsets/helper.go | 27 +++++++++++++++++++++++++++ 4 files changed, 53 insertions(+), 4 deletions(-) diff --git a/plan/plans/fields.go b/plan/plans/fields.go index c6198ab81e..6624f8c6a1 100644 --- a/plan/plans/fields.go +++ b/plan/plans/fields.go @@ -64,10 +64,9 @@ func (r *SelectFieldsDefaultPlan) Next(ctx context.Context) (row *plan.Row, err return nil, errors.Trace(err) } - r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { - v, err0 := GetIdentValue(name, r.Src.GetFields(), srcRow.Data, field.DefaultFieldFlag) - if err0 == nil { - return v, nil + r.evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + if scope == expression.IdentReferFromTable { + return srcRow.FromData[index], nil } return getIdentValueFromOuterQuery(ctx, name) diff --git a/rset/rsets/fields.go b/rset/rsets/fields.go index 091519fa59..624b7f26e2 100644 --- a/rset/rsets/fields.go +++ b/rset/rsets/fields.go @@ -39,10 +39,29 @@ type SelectFieldsRset struct { SelectList *plans.SelectList } +func fixSelectFieldsRefer(selectList *plans.SelectList) error { + visitor := newFromIdentVisitor(selectList.FromFields) + + // we only fix un-hidden fields here, for hidden fields, it should be + // handled in their own clause, in order by or having. + for i, f := range selectList.Fields[0:selectList.HiddenFieldOffset] { + e, err := f.Expr.Accept(visitor) + if err != nil { + return errors.Trace(err) + } + selectList.Fields[i].Expr = e + } + return nil +} + // Plan gets SrcPlan/SelectFieldsDefaultPlan. // If all fields are equal to src plan fields, then gets SrcPlan. // Default gets SelectFieldsDefaultPlan. func (r *SelectFieldsRset) Plan(ctx context.Context) (plan.Plan, error) { + if err := fixSelectFieldsRefer(r.SelectList); err != nil { + return nil, errors.Trace(err) + } + fields := r.SelectList.Fields srcFields := r.Src.GetFields() if len(fields) == len(srcFields) { @@ -97,6 +116,7 @@ func (r *SelectFromDualRset) Plan(ctx context.Context) (plan.Plan, error) { // field cannot contain identifier for _, f := range r.Fields { if cs := expression.MentionedColumns(f.Expr); len(cs) > 0 { + // TODO: check in outer query, like select * from t where t.c = (select c limit 1); return nil, errors.Errorf("Unknown column '%s' in 'field list'", cs[0]) } } diff --git a/rset/rsets/fields_test.go b/rset/rsets/fields_test.go index 36ad26f297..aa63daa5b4 100644 --- a/rset/rsets/fields_test.go +++ b/rset/rsets/fields_test.go @@ -69,6 +69,7 @@ func (s *testSelectFieldsPlannerSuite) TestDistinctPlanner(c *C) { oldFld := s.sr.SelectList.Fields[1] s.sr.SelectList.Fields = s.sr.SelectList.Fields[:1] + s.sr.SelectList.HiddenFieldOffset = len(s.sr.SelectList.Fields) p, err = s.sr.Plan(nil) c.Assert(err, IsNil) @@ -81,6 +82,7 @@ func (s *testSelectFieldsPlannerSuite) TestDistinctPlanner(c *C) { s.sr.Src = tdp s.sr.SelectList.Fields = []*field.Field{fld} + s.sr.SelectList.HiddenFieldOffset = len(s.sr.SelectList.Fields) p, err = s.sr.Plan(nil) c.Assert(err, IsNil) @@ -94,6 +96,7 @@ func (s *testSelectFieldsPlannerSuite) TestDistinctPlanner(c *C) { // cover isConst check, like `select 1, c1 from t` s.sr.SelectList.Fields = []*field.Field{fld, oldFld} + s.sr.SelectList.HiddenFieldOffset = len(s.sr.SelectList.Fields) p, err = s.sr.Plan(nil) c.Assert(err, IsNil) diff --git a/rset/rsets/helper.go b/rset/rsets/helper.go index 0a73c15247..1461e2f700 100644 --- a/rset/rsets/helper.go +++ b/rset/rsets/helper.go @@ -132,3 +132,30 @@ func checkIdent(i *expression.Ident, selectList *plans.SelectList, clause clause return idx[useIndex], nil } + +// fromIdentVisitor can only handle identifier which reference FROM table or outer query. +// like in common select list, where or join on condition. +type fromIdentVisitor struct { + expression.BaseVisitor + fromFields []*field.ResultField +} + +func (v *fromIdentVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) { + idx := field.GetResultFieldIndex(i.L, v.fromFields, field.DefaultFieldFlag) + if len(idx) > 0 { + i.ReferScope = expression.IdentReferFromTable + i.ReferIndex = idx[0] + return i, nil + } + + // TODO: check in outer query + return i, nil +} + +func newFromIdentVisitor(fromFields []*field.ResultField) *fromIdentVisitor { + visitor := &fromIdentVisitor{} + visitor.BaseVisitor.V = visitor + visitor.fromFields = fromFields + + return visitor +} From 894a009f40a085c8b6866483681227858b1a9c81 Mon Sep 17 00:00:00 2001 From: qiuyesuifeng Date: Sat, 17 Oct 2015 11:05:20 +0800 Subject: [PATCH 20/54] expression: tiny clean up. --- expression/helper.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/expression/helper.go b/expression/helper.go index a72071c803..60b31a7c4f 100644 --- a/expression/helper.go +++ b/expression/helper.go @@ -153,12 +153,11 @@ func (v *mentionedAggregateFuncsVisitor) VisitCall(c *Call) (Expression, error) if err != nil { return nil, errors.Trace(err) } - if isAggregate { v.exprs = append(v.exprs, c) } - // iaggregate function can't use aggregate function as the arg. + // aggregate function can't use aggregate function as the arg. n := len(v.exprs) for _, e := range c.Args { _, err := e.Accept(v) From 8a3dc764d57a51fcbb7e894fe670308ffee4db39 Mon Sep 17 00:00:00 2001 From: siddontang Date: Sat, 17 Oct 2015 11:46:56 +0800 Subject: [PATCH 21/54] rsets: update function name. --- rset/rsets/fields.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rset/rsets/fields.go b/rset/rsets/fields.go index 624b7f26e2..4df4726463 100644 --- a/rset/rsets/fields.go +++ b/rset/rsets/fields.go @@ -39,7 +39,7 @@ type SelectFieldsRset struct { SelectList *plans.SelectList } -func fixSelectFieldsRefer(selectList *plans.SelectList) error { +func updateSelectFieldsRefer(selectList *plans.SelectList) error { visitor := newFromIdentVisitor(selectList.FromFields) // we only fix un-hidden fields here, for hidden fields, it should be @@ -58,7 +58,7 @@ func fixSelectFieldsRefer(selectList *plans.SelectList) error { // If all fields are equal to src plan fields, then gets SrcPlan. // Default gets SelectFieldsDefaultPlan. func (r *SelectFieldsRset) Plan(ctx context.Context) (plan.Plan, error) { - if err := fixSelectFieldsRefer(r.SelectList); err != nil { + if err := updateSelectFieldsRefer(r.SelectList); err != nil { return nil, errors.Trace(err) } From 51e040c537363c6d90ee6c83f3d83a6e74c270ed Mon Sep 17 00:00:00 2001 From: qiuyesuifeng Date: Sat, 17 Oct 2015 11:50:09 +0800 Subject: [PATCH 22/54] *: use visitor to update where field eval. --- plan/plans/where.go | 8 ++++---- plan/plans/where_test.go | 4 +++- rset/rsets/where.go | 19 +++++++++++++++++-- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/plan/plans/where.go b/plan/plans/where.go index 5f9bf59eeb..2a1a866ff1 100644 --- a/plan/plans/where.go +++ b/plan/plans/where.go @@ -56,15 +56,15 @@ func (r *FilterDefaultPlan) Next(ctx context.Context) (row *plan.Row, err error) if row == nil || err != nil { return nil, errors.Trace(err) } - r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { - v, err0 := GetIdentValue(name, r.GetFields(), row.Data, field.DefaultFieldFlag) - if err0 == nil { - return v, nil + r.evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + if scope == expression.IdentReferFromTable { + return row.Data[index], nil } // try to find in outer query return getIdentValueFromOuterQuery(ctx, name) } + var meet bool meet, err = r.meetCondition(ctx) if err != nil { diff --git a/plan/plans/where_test.go b/plan/plans/where_test.go index 6c890608aa..afb5c0e863 100644 --- a/plan/plans/where_test.go +++ b/plan/plans/where_test.go @@ -44,7 +44,9 @@ func (t *testWhereSuit) TestWhere(c *C) { Expr: &expression.BinaryOperation{ Op: opcode.GE, L: &expression.Ident{ - CIStr: model.NewCIStr("id"), + CIStr: model.NewCIStr("id"), + ReferScope: 2, + ReferIndex: 0, }, R: expression.Value{ Val: 30, diff --git a/rset/rsets/where.go b/rset/rsets/where.go index 30729792a9..f32313ac8f 100644 --- a/rset/rsets/where.go +++ b/rset/rsets/where.go @@ -162,7 +162,6 @@ func (r *WhereRset) planStatic(ctx context.Context, e expression.Expression) (pl if err != nil { return nil, err } - if val == nil { // like `select * from t where null`. return &plans.NullPlan{Fields: r.Src.GetFields()}, nil @@ -172,7 +171,6 @@ func (r *WhereRset) planStatic(ctx context.Context, e expression.Expression) (pl if err != nil { return nil, err } - if n == 0 { // like `select * from t where 0`. return &plans.NullPlan{Fields: r.Src.GetFields()}, nil @@ -181,8 +179,25 @@ func (r *WhereRset) planStatic(ctx context.Context, e expression.Expression) (pl return &plans.FilterDefaultPlan{Plan: r.Src, Expr: e}, nil } +func (r *WhereRset) updateWhereFieldsRefer() error { + visitor := newFromIdentVisitor(r.Src.GetFields()) + + e, err := r.Expr.Accept(visitor) + if err != nil { + return errors.Trace(err) + } + + r.Expr = e + return nil +} + // Plan gets NullPlan/FilterDefaultPlan. func (r *WhereRset) Plan(ctx context.Context) (plan.Plan, error) { + // Update where fields refer expr by using fromIdentVisitor. + if err := r.updateWhereFieldsRefer(); err != nil { + return nil, errors.Trace(err) + } + expr := r.Expr.Clone() if expr.IsStatic() { // IsStaic means we have a const value for where condition, and we don't need any index. From 2a76b6cecd56f855734b5883617ddafd6d9706a9 Mon Sep 17 00:00:00 2001 From: siddontang Date: Sat, 17 Oct 2015 11:46:56 +0800 Subject: [PATCH 23/54] rsets: update function name. --- rset/rsets/fields.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rset/rsets/fields.go b/rset/rsets/fields.go index 624b7f26e2..4df4726463 100644 --- a/rset/rsets/fields.go +++ b/rset/rsets/fields.go @@ -39,7 +39,7 @@ type SelectFieldsRset struct { SelectList *plans.SelectList } -func fixSelectFieldsRefer(selectList *plans.SelectList) error { +func updateSelectFieldsRefer(selectList *plans.SelectList) error { visitor := newFromIdentVisitor(selectList.FromFields) // we only fix un-hidden fields here, for hidden fields, it should be @@ -58,7 +58,7 @@ func fixSelectFieldsRefer(selectList *plans.SelectList) error { // If all fields are equal to src plan fields, then gets SrcPlan. // Default gets SelectFieldsDefaultPlan. func (r *SelectFieldsRset) Plan(ctx context.Context) (plan.Plan, error) { - if err := fixSelectFieldsRefer(r.SelectList); err != nil { + if err := updateSelectFieldsRefer(r.SelectList); err != nil { return nil, errors.Trace(err) } From e280bc9b752d477c8a35d8cf553c75d03a3cec57 Mon Sep 17 00:00:00 2001 From: qiuyesuifeng Date: Sat, 17 Oct 2015 12:23:19 +0800 Subject: [PATCH 24/54] plan: address comment. --- plan/plans/where_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plan/plans/where_test.go b/plan/plans/where_test.go index afb5c0e863..f1effcfbc6 100644 --- a/plan/plans/where_test.go +++ b/plan/plans/where_test.go @@ -45,7 +45,7 @@ func (t *testWhereSuit) TestWhere(c *C) { Op: opcode.GE, L: &expression.Ident{ CIStr: model.NewCIStr("id"), - ReferScope: 2, + ReferScope: expression.IdentReferFromTable, ReferIndex: 0, }, R: expression.Value{ From 9e73c8d4a830b8386b55622f8ae3ca530ad941d8 Mon Sep 17 00:00:00 2001 From: siddontang Date: Sat, 17 Oct 2015 14:20:19 +0800 Subject: [PATCH 25/54] *: use Ident refer for order by. --- plan/plans/groupby.go | 29 ++++++- plan/plans/groupby_test.go | 12 ++- plan/plans/orderby.go | 9 +- plan/plans/orderby_test.go | 4 +- plan/plans/select_list.go | 45 +++++----- rset/rsets/groupby.go | 7 +- rset/rsets/groupby_test.go | 9 ++ rset/rsets/having.go | 17 +++- rset/rsets/having_test.go | 53 ------------ rset/rsets/helper.go | 2 +- rset/rsets/orderby.go | 169 ++++++++++++++++++++++++------------- rset/rsets/orderby_test.go | 64 -------------- stmt/stmts/select.go | 2 +- stmt/stmts/union.go | 8 +- tidb_test.go | 13 ++- 15 files changed, 218 insertions(+), 225 deletions(-) diff --git a/plan/plans/groupby.go b/plan/plans/groupby.go index a70ff3e692..d3af0b857b 100644 --- a/plan/plans/groupby.go +++ b/plan/plans/groupby.go @@ -128,7 +128,14 @@ func (r *GroupByDefaultPlan) getFieldValueByName(name string, out []interface{}) func (r *GroupByDefaultPlan) evalNoneAggFields(ctx context.Context, out []interface{}, m map[interface{}]interface{}, in []interface{}) error { - m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { + m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + if scope == expression.IdentReferFromTable { + return in[index], nil + } else if scope == expression.IdentReferSelectList { + return out[index], nil + } + + // TODO: remove following getting later. v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag) if err == nil { return v, nil @@ -153,10 +160,18 @@ func (r *GroupByDefaultPlan) evalNoneAggFields(ctx context.Context, out []interf func (r *GroupByDefaultPlan) evalAggFields(ctx context.Context, out []interface{}, m map[interface{}]interface{}, in []interface{}) error { + // Eval aggregate field results in ctx for i := range r.AggFields { if i < r.HiddenFieldOffset { - m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { + m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + if scope == expression.IdentReferFromTable { + return in[index], nil + } else if scope == expression.IdentReferSelectList { + return out[index], nil + } + + //TODO: remove following getting later v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag) if err == nil { return v, nil @@ -171,12 +186,20 @@ func (r *GroupByDefaultPlan) evalAggFields(ctx context.Context, out []interface{ // select c1 as a from t having count(a) = 1 // because all the select list data is generated before, so we can get it // when handling hidden field. - m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { + m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + if scope == expression.IdentReferFromTable { + return in[index], nil + } else if scope == expression.IdentReferSelectList { + return out[index], nil + } + + // TODO: remove later v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag) if err == nil { return v, nil } + // TODO: remove later // if we can not find in table, we will try to find in un-hidden select list // only hidden field can use this v, err = r.getFieldValueByName(name, out) diff --git a/plan/plans/groupby_test.go b/plan/plans/groupby_test.go index 4e28b226c3..9c3caab894 100644 --- a/plan/plans/groupby_test.go +++ b/plan/plans/groupby_test.go @@ -42,12 +42,16 @@ func (t *testGroupBySuite) TestGroupBy(c *C) { Fields: []*field.Field{ { Expr: &expression.Ident{ - CIStr: model.NewCIStr("id"), + CIStr: model.NewCIStr("id"), + ReferScope: expression.IdentReferFromTable, + ReferIndex: 0, }, }, { Expr: &expression.Ident{ - CIStr: model.NewCIStr("name"), + CIStr: model.NewCIStr("name"), + ReferScope: expression.IdentReferFromTable, + ReferIndex: 1, }, }, { @@ -55,7 +59,9 @@ func (t *testGroupBySuite) TestGroupBy(c *C) { F: "sum", Args: []expression.Expression{ &expression.Ident{ - CIStr: model.NewCIStr("id"), + CIStr: model.NewCIStr("id"), + ReferScope: expression.IdentReferFromTable, + ReferIndex: 0, }, }, }, diff --git a/plan/plans/orderby.go b/plan/plans/orderby.go index 347f536e48..4708339203 100644 --- a/plan/plans/orderby.go +++ b/plan/plans/orderby.go @@ -158,10 +158,11 @@ func (r *OrderByDefaultPlan) fetchAll(ctx context.Context) error { if row == nil { break } - evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { - v, err := GetIdentValue(name, r.ResultFields, row.Data, field.CheckFieldFlag) - if err == nil { - return v, nil + evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + if scope == expression.IdentReferFromTable { + return row.FromData[index], nil + } else if scope == expression.IdentReferSelectList { + return row.Data[index], nil } // try to find in outer query diff --git a/plan/plans/orderby_test.go b/plan/plans/orderby_test.go index 2b434300ad..828081016d 100644 --- a/plan/plans/orderby_test.go +++ b/plan/plans/orderby_test.go @@ -48,7 +48,9 @@ func (t *testOrderBySuit) TestOrderBy(c *C) { Src: tblPlan, By: []expression.Expression{ &expression.Ident{ - CIStr: model.NewCIStr("id"), + CIStr: model.NewCIStr("id"), + ReferScope: expression.IdentReferSelectList, + ReferIndex: 0, }, }, Ascs: []bool{false}, diff --git a/plan/plans/select_list.go b/plan/plans/select_list.go index a77fad8ec5..bacc7b729d 100644 --- a/plan/plans/select_list.go +++ b/plan/plans/select_list.go @@ -95,36 +95,33 @@ func (s *SelectList) GetFields() []*field.ResultField { } // UpdateAggFields adds aggregate function resultfield to select result field list. -func (s *SelectList) UpdateAggFields(expr expression.Expression, tableFields []*field.ResultField) (expression.Expression, error) { - // For aggregate function, the name can be in table or select list. - names := expression.MentionedColumns(expr) - - for _, name := range names { - if field.ContainFieldName(name, tableFields, field.DefaultFieldFlag) { - continue - } - - if field.ContainFieldName(name, s.ResultFields, field.DefaultFieldFlag) { - continue - } - - return nil, errors.Errorf("Unknown column '%s'", name) - } - +func (s *SelectList) UpdateAggFields(expr expression.Expression) (expression.Expression, error) { // We must add aggregate function to hidden select list // and use a position expression to fetch its value later. - exprName := expr.String() - idx := field.GetResultFieldIndex(exprName, s.ResultFields, field.CheckFieldFlag) - if len(idx) == 0 { - f := &field.Field{Expr: expr} - resultField := &field.ResultField{Name: exprName} - s.AddField(f, resultField) + name := strings.ToLower(expr.String()) + index := -1 + for i := 0; i < s.HiddenFieldOffset; i++ { + // only check origin name, e,g. "select sum(c1) as a from t order by sum(c1)" + // or "select sum(c1) from t order by sum(c1)" + if s.ResultFields[i].ColumnInfo.Name.L == name { + index = i + break + } + } - return &expression.Position{N: len(s.Fields), Name: exprName}, nil + if index == -1 { + f := &field.Field{Expr: expr} + s.AddField(f, nil) + + pos := len(s.Fields) + + s.AggFields[pos-1] = struct{}{} + + return &expression.Position{N: pos, Name: name}, nil } // select list has this field, use it directly. - return &expression.Position{N: idx[0] + 1, Name: exprName}, nil + return &expression.Position{N: index + 1, Name: name}, nil } // CloneHiddenField checks and clones field and result field from table fields, diff --git a/rset/rsets/groupby.go b/rset/rsets/groupby.go index 008a32fb71..176eee5e93 100644 --- a/rset/rsets/groupby.go +++ b/rset/rsets/groupby.go @@ -50,7 +50,7 @@ func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, // select c1 as a, c2 as a from t group by a is ambiguous // select c1 as a, c2 as a from t group by a + 1 is ambiguous // select c1 as c2, c2 from t group by c2 is ambiguous - // select c1 as c2, c2 from t group by c2 + 1 is ambiguous + // select c1 as c2, c2 from t group by c2 + 1 is not ambiguous var ( index int @@ -115,9 +115,10 @@ func (v *groupByVisitor) VisitCall(c *expression.Call) (expression.Expression, e // Plan gets GroupByDefaultPlan. func (r *GroupByRset) Plan(ctx context.Context) (plan.Plan, error) { - fields := r.SelectList.Fields + if err := updateSelectFieldsRefer(r.SelectList); err != nil { + return nil, errors.Trace(err) + } - r.SelectList.AggFields = GetAggFields(fields) visitor := &groupByVisitor{} visitor.BaseVisitor.V = visitor visitor.selectList = r.SelectList diff --git a/rset/rsets/groupby_test.go b/rset/rsets/groupby_test.go index 9a87172316..eebe20d9cb 100644 --- a/rset/rsets/groupby_test.go +++ b/rset/rsets/groupby_test.go @@ -50,6 +50,11 @@ func (s *testGroupByRsetSuite) SetUpSuite(c *C) { s.r = &GroupByRset{Src: tblPlan, SelectList: selectList, By: by} } +func resetAggFields(selectList *plans.SelectList) { + fields := selectList.Fields + selectList.AggFields = GetAggFields(fields) +} + func (s *testGroupByRsetSuite) TestGroupByRsetPlan(c *C) { // `select id, name from t group by name` p, err := s.r.Plan(nil) @@ -88,6 +93,8 @@ func (s *testGroupByRsetSuite) TestGroupByRsetPlan(c *C) { fld := &field.Field{Expr: fldExpr, AsName: "a"} s.r.SelectList.Fields[0] = fld + resetAggFields(s.r.SelectList) + s.r.By[0] = expression.Value{Val: int64(1)} _, err = s.r.Plan(nil) @@ -115,6 +122,8 @@ func (s *testGroupByRsetSuite) TestGroupByRsetPlan(c *C) { s.r.SelectList.ResultFields[0].Col.Name = model.NewCIStr("count(id)") s.r.SelectList.ResultFields[0].Name = "a" + resetAggFields(s.r.SelectList) + _, err = s.r.Plan(nil) c.Assert(err, NotNil) diff --git a/rset/rsets/having.go b/rset/rsets/having.go index 39ba3cc561..4a3ad02080 100644 --- a/rset/rsets/having.go +++ b/rset/rsets/having.go @@ -33,10 +33,25 @@ type HavingRset struct { SelectList *plans.SelectList } +// CheckAggregate will check whether order by has aggregate function or not, +// if has, we will add it to select list hidden field. +func (r *HavingRset) CheckAggregate(selectList *plans.SelectList) error { + if expression.ContainAggregateFunc(r.Expr) { + expr, err := selectList.UpdateAggFields(r.Expr) + if err != nil { + return errors.Errorf("%s in 'having clause'", err.Error()) + } + + r.Expr = expr + } + + return nil +} + // CheckAndUpdateSelectList checks having fields validity and set hidden fields to selectList. func (r *HavingRset) CheckAndUpdateSelectList(selectList *plans.SelectList, groupBy []expression.Expression, tableFields []*field.ResultField) error { if expression.ContainAggregateFunc(r.Expr) { - expr, err := selectList.UpdateAggFields(r.Expr, tableFields) + expr, err := selectList.UpdateAggFields(r.Expr) if err != nil { return errors.Errorf("%s in 'having clause'", err.Error()) } diff --git a/rset/rsets/having_test.go b/rset/rsets/having_test.go index 91a13e3052..180f57d779 100644 --- a/rset/rsets/having_test.go +++ b/rset/rsets/having_test.go @@ -50,59 +50,6 @@ func (s *testHavingRsetSuite) SetUpSuite(c *C) { s.r = &HavingRset{Src: tblPlan, Expr: expr, SelectList: selectList} } -func (s *testHavingRsetSuite) TestHavingRsetCheckAndUpdateSelectList(c *C) { - resultFields := s.r.Src.GetFields() - - selectList := s.r.SelectList - - groupBy := []expression.Expression{} - - // `select id, name from t having id > 1` - err := s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields) - c.Assert(err, IsNil) - - // `select name from t group by id having id > 1` - selectList.ResultFields = selectList.ResultFields[1:] - selectList.Fields = selectList.Fields[1:] - - groupBy = []expression.Expression{&expression.Ident{CIStr: model.NewCIStr("id")}} - err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields) - c.Assert(err, IsNil) - - // `select name from t group by id + 1 having id > 1` - expr := expression.NewBinaryOperation(opcode.Plus, &expression.Ident{CIStr: model.NewCIStr("id")}, expression.Value{Val: 1}) - - groupBy = []expression.Expression{expr} - err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields) - c.Assert(err, IsNil) - - // `select name from t group by id + 1 having count(1) > 1` - aggExpr, err := expression.NewCall("count", []expression.Expression{expression.Value{Val: 1}}, false) - c.Assert(err, IsNil) - - s.r.Expr = aggExpr - - err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields) - c.Assert(err, IsNil) - - // `select name from t group by id + 1 having count(xxx) > 1` - aggExpr, err = expression.NewCall("count", []expression.Expression{&expression.Ident{CIStr: model.NewCIStr("xxx")}}, false) - c.Assert(err, IsNil) - - s.r.Expr = aggExpr - - err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields) - c.Assert(err, NotNil) - - // `select name from t group by id having xxx > 1` - expr = expression.NewBinaryOperation(opcode.GT, &expression.Ident{CIStr: model.NewCIStr("xxx")}, expression.Value{Val: 1}) - - s.r.Expr = expr - - err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields) - c.Assert(err, NotNil) -} - func (s *testHavingRsetSuite) TestHavingRsetPlan(c *C) { p, err := s.r.Plan(nil) c.Assert(err, IsNil) diff --git a/rset/rsets/helper.go b/rset/rsets/helper.go index 1461e2f700..abea6bb44f 100644 --- a/rset/rsets/helper.go +++ b/rset/rsets/helper.go @@ -22,7 +22,7 @@ import ( // GetAggFields gets aggregate fields position map. func GetAggFields(fields []*field.Field) map[int]struct{} { - aggFields := map[int]struct{}{} + aggFields := make(map[int]struct{}, len(fields)) for i, v := range fields { if expression.ContainAggregateFunc(v.Expr) { aggFields[i] = struct{}{} diff --git a/rset/rsets/orderby.go b/rset/rsets/orderby.go index e0b9e74834..c16daf5ba4 100644 --- a/rset/rsets/orderby.go +++ b/rset/rsets/orderby.go @@ -65,46 +65,112 @@ func (r *OrderByRset) String() string { return strings.Join(a, ", ") } -// CheckAndUpdateSelectList checks order by fields validity and set hidden fields to selectList. -func (r *OrderByRset) CheckAndUpdateSelectList(selectList *plans.SelectList, tableFields []*field.ResultField) error { +// CheckAggregate will check whether order by has aggregate function or not, +// if has, we will add it to select list hidden field. +func (r *OrderByRset) CheckAggregate(selectList *plans.SelectList) error { for i, v := range r.By { if expression.ContainAggregateFunc(v.Expr) { - expr, err := selectList.UpdateAggFields(v.Expr, tableFields) + expr, err := selectList.UpdateAggFields(v.Expr) if err != nil { return errors.Errorf("%s in 'order clause'", err.Error()) } r.By[i].Expr = expr - } else { - if _, err := selectList.CheckReferAmbiguous(v.Expr); err != nil { - return errors.Errorf("Column '%s' in order statement is ambiguous", v.Expr) - } + } + } + return nil +} - // TODO: check more ambiguous case - // Order by ambiguous rule: - // select c1 as a, c2 as a from t order by a is ambiguous - // select c1 as a, c2 as a from t order by a + 1 is ambiguous - // select c1 as c2, c2 from t order by c2 is ambiguous - // select c1 as c2, c2 from t order by c2 + 1 is ambiguous +type orderByVisitor struct { + expression.BaseVisitor + selectList *plans.SelectList + rootIdent *expression.Ident +} - // TODO: use vistor to refactor all and combine following plan check. - names := expression.MentionedColumns(v.Expr) - for _, name := range names { - // try to find in select list - // TODO: mysql has confused result for this, see #555. - // now we use select list then order by, later we should make it easier. - if field.ContainFieldName(name, selectList.ResultFields, field.CheckFieldFlag) { - continue - } +func (v *orderByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) { + // Order by ambiguous rule: + // select c1 as a, c2 as a from t order by a is ambiguous + // select c1 as a, c2 as a from t order by a + 1 is ambiguous + // select c1 as c2, c2 from t order by c2 is ambiguous + // select c1 as c2, c2 from t order by c2 + 1 is not ambiguous - if !selectList.CloneHiddenField(name, tableFields) { - return errors.Errorf("Unknown column '%s' in 'order clause'", name) - } - } + // Order by identifier reference check + // select c1 as c2 from t order by c2, c2 in order by references c1 in select field. + // select c1 as c2 from t order by c2 + 1, c2 in order by c2 + 1 references t.c2 in from table. + // select c1 as c2 from t order by sum(c2), c2 in order by sum(c2) references t.c2 in from table. + // select c1 as c2 from t order by sum(1) + c2, c2 in order by sum(1) + c2 references t.c2 in from table. + // select c1 as a from t order by a, a references c1 in select field. + // select c1 as a from t order by sum(a), a in order by sum(a) references c1 in select field. + + // TODO: unify VisitIdent for group by and order by visitor, they are little different. + + var ( + index int + err error + ) + + if v.rootIdent == i { + // The order by is an identifier, we must check it first. + index, err = checkIdent(i, v.selectList, orderByClause) + if err != nil { + return nil, errors.Trace(err) + } + + if index >= 0 { + // identifier references a select field. use it directly. + // e,g. select c1 as c2 from t order by c2, here c2 references c1. + i.ReferScope = expression.IdentReferSelectList + i.ReferIndex = index + return i, nil } } - return nil + // find this identifier in FROM. + idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag) + if len(idx) > 0 { + i.ReferScope = expression.IdentReferFromTable + i.ReferIndex = idx[0] + return i, nil + } + + if v.rootIdent != i { + // This identifier is the part of the order by, check ambiguous here. + index, err = checkIdent(i, v.selectList, orderByClause) + if err != nil { + return nil, errors.Trace(err) + } + } + + // try to find in select list, we have got index using checkIdent before. + if index >= 0 { + // find in select list + i.ReferScope = expression.IdentReferSelectList + i.ReferIndex = index + return i, nil + } + + // TODO: check in out query + // TODO: return unknown field error, but now just return directly. + // Because this may reference outer query. + return i, nil +} + +func (v *orderByVisitor) VisitPosition(p *expression.Position) (expression.Expression, error) { + n := p.N + if p.N <= v.selectList.HiddenFieldOffset { + // this position expression is not in hidden field, no need to check + return p, nil + } + + // we have added order by to hidden field before, now visit it here. + expr := v.selectList.Fields[n-1].Expr + e, err := expr.Accept(v) + if err != nil { + return nil, errors.Trace(err) + } + v.selectList.Fields[n-1].Expr = e + + return p, nil } // Plan get SrcPlan/OrderByDefaultPlan. @@ -120,45 +186,28 @@ func (r *OrderByRset) Plan(ctx context.Context) (plan.Plan, error) { ascs []bool ) - fields := r.Src.GetFields() + visitor := &orderByVisitor{} + visitor.BaseVisitor.V = visitor + visitor.selectList = r.SelectList + for i := range r.By { e := r.By[i].Expr - if v, ok := e.(expression.Value); ok { - var ( - position int - isPosition = true - ) - switch u := v.Val.(type) { - case int64: - position = int(u) - case uint64: - position = int(u) - default: - isPosition = false - // only const value - } + pos, err := castPosition(e, r.SelectList, orderByClause) + if err != nil { + return nil, errors.Trace(err) + } - if isPosition { - if position < 1 || position > len(fields) { - return nil, errors.Errorf("Unknown column '%d' in 'order clause'", position) - } - - // use Position expression for the associated field. - r.By[i].Expr = &expression.Position{N: position} - } + if pos != nil { + // use Position expression for the associated field. + r.By[i].Expr = pos } else { - // Don't check ambiguous here, only check field exists or not. - // TODO: use visitor to refactor. - colNames := expression.MentionedColumns(e) - for _, name := range colNames { - if idx := field.GetResultFieldIndex(name, r.SelectList.ResultFields, field.DefaultFieldFlag); len(idx) == 0 { - // find in from - if idx = field.GetResultFieldIndex(name, r.SelectList.FromFields, field.DefaultFieldFlag); len(idx) == 0 { - return nil, errors.Errorf("unknown field %s", name) - } - } + visitor.rootIdent = castIdent(e) + e, err = e.Accept(visitor) + if err != nil { + return nil, errors.Trace(err) } + r.By[i].Expr = e } by = append(by, r.By[i].Expr) diff --git a/rset/rsets/orderby_test.go b/rset/rsets/orderby_test.go index afd4b5313e..9a51c381b7 100644 --- a/rset/rsets/orderby_test.go +++ b/rset/rsets/orderby_test.go @@ -47,64 +47,6 @@ func (s *testOrderByRsetSuite) SetUpSuite(c *C) { s.r = &OrderByRset{Src: tblPlan, SelectList: selectList} } -func (s *testOrderByRsetSuite) TestOrderByRsetCheckAndUpdateSelectList(c *C) { - resultFields := s.r.Src.GetFields() - - fields := make([]*field.Field, len(resultFields)) - for i, resultField := range resultFields { - name := resultField.Name - fields[i] = &field.Field{Expr: &expression.Ident{CIStr: model.NewCIStr(name)}} - } - - selectList := &plans.SelectList{ - HiddenFieldOffset: len(resultFields), - ResultFields: resultFields, - Fields: fields, - } - - expr := &expression.Ident{CIStr: model.NewCIStr("id")} - orderByItem := OrderByItem{Expr: expr, Asc: true} - by := []OrderByItem{orderByItem} - r := &OrderByRset{By: by, SelectList: selectList} - - // `select id, name from t order by id` - err := r.CheckAndUpdateSelectList(selectList, resultFields) - c.Assert(err, IsNil) - - // `select id, name as id from t order by id` - selectList.Fields[1].AsName = "id" - selectList.ResultFields[1].Name = "id" - - err = r.CheckAndUpdateSelectList(selectList, resultFields) - c.Assert(err, NotNil) - - // `select id, name from t order by count(1) > 1` - aggExpr, err := expression.NewCall("count", []expression.Expression{expression.Value{Val: 1}}, false) - c.Assert(err, IsNil) - - r.By[0].Expr = aggExpr - - err = r.CheckAndUpdateSelectList(selectList, resultFields) - c.Assert(err, IsNil) - - // `select id, name from t order by count(xxx) > 1` - aggExpr, err = expression.NewCall("count", []expression.Expression{&expression.Ident{CIStr: model.NewCIStr("xxx")}}, false) - c.Assert(err, IsNil) - - r.By[0].Expr = aggExpr - - err = r.CheckAndUpdateSelectList(selectList, resultFields) - c.Assert(err, NotNil) - - // `select id, name from t order by xxx` - r.By[0].Expr = &expression.Ident{CIStr: model.NewCIStr("xxx")} - selectList.Fields[1].AsName = "name" - selectList.ResultFields[1].Name = "name" - - err = r.CheckAndUpdateSelectList(selectList, resultFields) - c.Assert(err, NotNil) -} - func (s *testOrderByRsetSuite) TestOrderByRsetPlan(c *C) { // `select id, name from t` _, err := s.r.Plan(nil) @@ -145,12 +87,6 @@ func (s *testOrderByRsetSuite) TestOrderByRsetPlan(c *C) { _, err = s.r.Plan(nil) c.Assert(err, IsNil) - // `select id, name from t order by xxx` - s.r.By[0].Expr = &expression.Ident{CIStr: model.NewCIStr("xxx")} - - _, err = s.r.Plan(nil) - c.Assert(err, NotNil) - // check src plan is NullPlan s.r.Src = &plans.NullPlan{Fields: s.r.SelectList.ResultFields} diff --git a/stmt/stmts/select.go b/stmt/stmts/select.go index 8b889dc5a2..1b0b48b16d 100644 --- a/stmt/stmts/select.go +++ b/stmt/stmts/select.go @@ -182,7 +182,7 @@ func (s *SelectStmt) Plan(ctx context.Context) (plan.Plan, error) { if s.OrderBy != nil { // `order by` may contain aggregate functions, and we will add this to hidden fields. - if err = s.OrderBy.CheckAndUpdateSelectList(selectList, r.GetFields()); err != nil { + if err = s.OrderBy.CheckAggregate(selectList); err != nil { return nil, errors.Trace(err) } } diff --git a/stmt/stmts/union.go b/stmt/stmts/union.go index 90390a7187..bd569f1c67 100644 --- a/stmt/stmts/union.go +++ b/stmt/stmts/union.go @@ -92,13 +92,14 @@ func (s *UnionStmt) Plan(ctx context.Context) (plan.Plan, error) { selectList := &plans.SelectList{} selectList.ResultFields = make([]*field.ResultField, len(fields)) selectList.HiddenFieldOffset = len(fields) + selectList.Fields = s.Selects[0].Fields // Union uses first select return column names and ignores table name. // We only care result name and type here. for i, f := range fields { - nf := &field.ResultField{} - nf.Name = f.Name - nf.FieldType = f.FieldType + nf := f.Clone() + nf.OrgTableName = "" + nf.TableName = "" selectList.ResultFields[i] = nf } @@ -109,6 +110,7 @@ func (s *UnionStmt) Plan(ctx context.Context) (plan.Plan, error) { r = &plans.UnionPlan{Srcs: srcs, Distincts: s.Distincts, RFields: selectList.ResultFields} + // TODO: check aggregate function later. if s := s.OrderBy; s != nil { if r, err = (&rsets.OrderByRset{By: s.By, Src: r, diff --git a/tidb_test.go b/tidb_test.go index aa082ab6a7..120a64e719 100644 --- a/tidb_test.go +++ b/tidb_test.go @@ -1165,11 +1165,12 @@ func (s *testSessionSuite) TestOrderBy(c *C) { mustExecMatch(c, se, "select c1 as a, t.c1 as a from t order by a desc", [][]interface{}{{2, 2}, {1, 1}}) mustExecMatch(c, se, "select c1 as c2 from t order by c2", [][]interface{}{{1}, {2}}) mustExecMatch(c, se, "select sum(c1) from t order by sum(c1)", [][]interface{}{{3}}) - - // TODO: now this test result is not same as MySQL, we will update it later. - mustExecMatch(c, se, "select c1 as c2 from t order by c2 + 1", [][]interface{}{{1}, {2}}) + mustExecMatch(c, se, "select c1 as c2 from t order by c2 + 1", [][]interface{}{{2}, {1}}) mustExecFailed(c, se, "select c1 as a, c2 as a from t order by a") + + mustExecFailed(c, se, "(select c1 as c2, c2 from t) union (select c1, c2 from t) order by c2") + mustExecFailed(c, se, "(select c1 as c2, c2 from t) union (select c1, c2 from t) order by c1") } func (s *testSessionSuite) TestHaving(c *C) { @@ -1263,6 +1264,10 @@ func mustExecMatch(c *C, se Session, sql string, expected [][]interface{}) { } func mustExecFailed(c *C, se Session, sql string, args ...interface{}) { - _, err := exec(c, se, sql, args...) + r, err := exec(c, se, sql, args...) + if err == nil { + // sometimes we may meet error after executing first row. + _, err = r.FirstRow() + } c.Assert(err, NotNil) } From cd4e4072a4a200b2c1cd0ad57f42ea8e74bb5741 Mon Sep 17 00:00:00 2001 From: dongxu Date: Sat, 17 Oct 2015 14:55:06 +0800 Subject: [PATCH 26/54] mvcc: add compactor implementation --- kv/compactor.go | 29 ++++ kv/kv.go | 1 + store/localstore/compactor.go | 189 +++++++++++++++++++++ store/localstore/compactor_test.go | 74 ++++++++ store/localstore/gc_bench/main.go | 74 ++++++++ store/localstore/kv.go | 21 ++- store/localstore/local_gc.go | 76 --------- store/localstore/local_version_provider.go | 4 + store/localstore/txn.go | 5 +- 9 files changed, 391 insertions(+), 82 deletions(-) create mode 100644 kv/compactor.go create mode 100644 store/localstore/compactor.go create mode 100644 store/localstore/compactor_test.go create mode 100644 store/localstore/gc_bench/main.go delete mode 100644 store/localstore/local_gc.go diff --git a/kv/compactor.go b/kv/compactor.go new file mode 100644 index 0000000000..b46e56237c --- /dev/null +++ b/kv/compactor.go @@ -0,0 +1,29 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package kv + +import "time" + +type CompactorPolicy struct { + SafeTime int + TriggerInterval time.Duration + BatchDeleteSize int +} + +type Compactor interface { + OnGet(k Key) + OnSet(k Key) + OnDelete(k Key) + Compact(ctx interface{}, k Key) error +} diff --git a/kv/kv.go b/kv/kv.go index 6098f0f604..90de6933ef 100644 --- a/kv/kv.go +++ b/kv/kv.go @@ -164,6 +164,7 @@ type Storage interface { Close() error // Storage's unique ID UUID() string + DumpRaw() } // FnKeyCmp is the function for iterator the keys diff --git a/store/localstore/compactor.go b/store/localstore/compactor.go new file mode 100644 index 0000000000..6cf12dd005 --- /dev/null +++ b/store/localstore/compactor.go @@ -0,0 +1,189 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package localstore + +import ( + "sync" + "time" + + "github.com/juju/errors" + "github.com/ngaut/log" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/store/localstore/engine" + "github.com/pingcap/tidb/util/bytes" +) + +var _ kv.Compactor = (*localstoreCompactor)(nil) + +const ( + deleteWorkerSize = 3 +) + +var localCompactorDefaultPolicy = kv.CompactorPolicy{ + SafeTime: 20 * 1000, // in ms + TriggerInterval: 1 * time.Second, + BatchDeleteSize: 100, +} + +type localstoreCompactor struct { + mu sync.Mutex + recentKeys map[string]struct{} + stopChan chan struct{} + delChan chan kv.EncodedKey + ticker *time.Ticker + db engine.DB + policy kv.CompactorPolicy +} + +func (gc *localstoreCompactor) OnSet(k kv.Key) { + gc.mu.Lock() + defer gc.mu.Unlock() + gc.recentKeys[string(k)] = struct{}{} +} + +func (gc *localstoreCompactor) OnGet(k kv.Key) { + // Do nothing now. +} + +func (gc *localstoreCompactor) OnDelete(k kv.Key) { + gc.mu.Lock() + defer gc.mu.Unlock() + gc.recentKeys[string(k)] = struct{}{} +} + +func (gc *localstoreCompactor) getAllVersions(k kv.Key) ([]kv.EncodedKey, error) { + startKey := MvccEncodeVersionKey(k, kv.MaxVersion) + endKey := MvccEncodeVersionKey(k, kv.MinVersion) + + it, err := gc.db.Seek(startKey) + if err != nil { + return nil, errors.Trace(err) + } + defer it.Release() + + var ret []kv.EncodedKey + for it.Next() { + if kv.EncodedKey(it.Key()).Cmp(endKey) < 0 { + ret = append(ret, bytes.CloneBytes(kv.EncodedKey(it.Key()))) + } + } + return ret, nil +} + +func (gc *localstoreCompactor) deleteWorker() { + cnt := 0 + batch := gc.db.NewBatch() +L: + for { + select { + case <-gc.stopChan: + break L + case key := <-gc.delChan: + { + cnt++ + batch.Delete(key) + // Batch delete. + if cnt == gc.policy.BatchDeleteSize { + err := gc.db.Commit(batch) + if err != nil { + log.Error(err) + } + batch = gc.db.NewBatch() + cnt = 0 + } + } + } + } +} + +func (gc *localstoreCompactor) checkExpiredKeysWorker() { +L: + for { + select { + case <-gc.stopChan: + break L + case <-gc.ticker.C: + log.Info("GC trigger") + gc.mu.Lock() + m := gc.recentKeys + gc.recentKeys = make(map[string]struct{}) + gc.mu.Unlock() + // Do Compactor + if len(m) > 0 { + for k, _ := range m { + err := gc.Compact(nil, []byte(k)) + if err != nil { + log.Error(err) + } + } + } + } + } + log.Info("GC Stopped") +} + +func (gc *localstoreCompactor) filterExpiredKeys(keys []kv.EncodedKey) []kv.EncodedKey { + var ret []kv.EncodedKey + for _, k := range keys { + _, ver, err := MvccDecode(k) + if err != nil { + // Should not happen. + panic(err) + } + ts := LocalVersionToTimestamp(ver) + currentTs := time.Now().UnixNano() / int64(time.Millisecond) + // Check timeout keys. + if currentTs-int64(ts) >= int64(gc.policy.SafeTime) { + ret = append(ret, k) + } + } + return ret +} + +func (gc *localstoreCompactor) Compact(ctx interface{}, k kv.Key) error { + keys, err := gc.getAllVersions(k) + if err != nil { + return errors.Trace(err) + } + for _, key := range gc.filterExpiredKeys(keys) { + // Send timeout key to deleteWorker. + log.Info("GC send key to deleteWorker", key) + gc.delChan <- key + } + return nil +} + +func (gc *localstoreCompactor) Start() { + // Start workers. + go gc.checkExpiredKeysWorker() + for i := 0; i < deleteWorkerSize; i++ { + go gc.deleteWorker() + } +} + +func (gc *localstoreCompactor) Stop() { + gc.ticker.Stop() + close(gc.stopChan) +} + +func newLocalCompactor(policy kv.CompactorPolicy, db engine.DB) *localstoreCompactor { + return &localstoreCompactor{ + recentKeys: make(map[string]struct{}), + stopChan: make(chan struct{}), + delChan: make(chan kv.EncodedKey, 100), + ticker: time.NewTicker(policy.TriggerInterval), + policy: policy, + db: db, + } +} diff --git a/store/localstore/compactor_test.go b/store/localstore/compactor_test.go new file mode 100644 index 0000000000..27cab4a379 --- /dev/null +++ b/store/localstore/compactor_test.go @@ -0,0 +1,74 @@ +package localstore + +import ( + "time" + + "github.com/ngaut/log" + . "github.com/pingcap/check" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/store/localstore/engine" +) + +var _ = Suite(&localstoreCompactorTestSuite{}) + +type localstoreCompactorTestSuite struct { +} + +func count(db engine.DB) int { + it, _ := db.Seek([]byte{0}) + defer it.Release() + totalCnt := 0 + for it.Next() { + log.Error(it.Key()) + totalCnt++ + } + return totalCnt +} + +func (s *localstoreCompactorTestSuite) TestCompactor(c *C) { + store := createMemStore() + db := store.(*dbStore).db + store.(*dbStore).compactor.Stop() + + policy := kv.CompactorPolicy{ + SafeTime: 500, + BatchDeleteSize: 1, + TriggerInterval: 100 * time.Millisecond, + } + compactor := newLocalCompactor(policy, db) + store.(*dbStore).compactor = compactor + + compactor.Start() + + txn, _ := store.Begin() + txn.Set([]byte("a"), []byte("1")) + txn.Commit() + txn, _ = store.Begin() + txn.Set([]byte("a"), []byte("2")) + txn.Commit() + txn, _ = store.Begin() + txn.Set([]byte("a"), []byte("3")) + txn.Commit() + txn, _ = store.Begin() + txn.Set([]byte("a"), []byte("3")) + txn.Commit() + txn, _ = store.Begin() + txn.Set([]byte("a"), []byte("4")) + txn.Commit() + txn, _ = store.Begin() + txn.Set([]byte("a"), []byte("5")) + txn.Commit() + t := count(db) + c.Assert(t, Equals, 7) + + // Simulating timeout + time.Sleep(1 * time.Second) + // Touch a, tigger GC + txn, _ = store.Begin() + txn.Set([]byte("a"), []byte("b")) + txn.Commit() + time.Sleep(1 * time.Second) + // Do background GC + t = count(db) + c.Assert(t, Equals, 2) +} diff --git a/store/localstore/gc_bench/main.go b/store/localstore/gc_bench/main.go new file mode 100644 index 0000000000..893b08c1ff --- /dev/null +++ b/store/localstore/gc_bench/main.go @@ -0,0 +1,74 @@ +package main + +import ( + "flag" + "fmt" + "time" + + "github.com/ngaut/log" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/store/localstore" + "github.com/pingcap/tidb/store/localstore/goleveldb" +) + +var ( + store kv.Storage + logLevel = flag.String("L", "error", "Log level") +) + +// init memory store +func init() { + path := fmt.Sprintf("memory://%d", time.Now().UnixNano()) + d := localstore.Driver{ + goleveldb.MemoryDriver{}, + } + var err error + store, err = d.Open(path) + if err != nil { + panic(err) + } +} + +func dump() { + startTs := time.Now() + tx, _ := store.Begin() + it, err := tx.Seek([]byte{0}, nil) + if err != nil { + log.Error(err) + } + cnt := 0 + for it.Valid() { + log.Info(it.Key(), it.Value()) + it, _ = it.Next(nil) + cnt++ + } + tx.Commit() + elapse := time.Since(startTs) + fmt.Println(cnt, elapse) +} + +func renew() { + tx, _ := store.Begin() + for i := 0; i < 10000; i++ { + key := fmt.Sprintf("record-%d", i) + val := fmt.Sprintf("%d", time.Now().Unix()) + tx.Set([]byte(key), []byte(val)) + } + tx.Commit() +} + +func main() { + flag.Parse() + log.SetLevelByString(*logLevel) + + for i := 0; i < 100; i++ { + fmt.Printf("\n====== Round %d =====\n", i) + if i%2 == 0 { + renew() + dump() + store.DumpRaw() + } + fmt.Println("====================") + time.Sleep(2 * time.Second) + } +} diff --git a/store/localstore/kv.go b/store/localstore/kv.go index 9871f7fda5..4df0665fdd 100644 --- a/store/localstore/kv.go +++ b/store/localstore/kv.go @@ -14,6 +14,7 @@ package localstore import ( + "fmt" "sync" "time" @@ -38,7 +39,7 @@ type dbStore struct { keysLocked map[string]uint64 uuid string path string - gc *localstoreGC + compactor *localstoreCompactor } type storeCache struct { @@ -86,11 +87,10 @@ func (d Driver) Open(schema string) (kv.Storage, error) { uuid: uuid.NewV4().String(), path: schema, db: db, - gc: newLocalGC(), + compactor: newLocalCompactor(localCompactorDefaultPolicy, db), } - mc.cache[schema] = s - + s.compactor.Start() return s, nil } @@ -141,6 +141,7 @@ func (s *dbStore) Begin() (kv.Transaction, error) { func (s *dbStore) Close() error { mc.mu.Lock() defer mc.mu.Unlock() + s.compactor.Stop() delete(mc.cache, s.path) return s.db.Close() } @@ -196,6 +197,18 @@ func (s *dbStore) tryConditionLockKey(tID uint64, key string, snapshotVal []byte return nil } +func (s *dbStore) DumpRaw() { + startTs := time.Now() + it, _ := s.db.Seek([]byte{0}) + defer it.Release() + cnt := 0 + for it.Next() { + cnt++ + } + elapse := time.Since(startTs) + fmt.Println(cnt, elapse) +} + func (s *dbStore) unLockKeys(keys ...string) error { s.mu.Lock() defer s.mu.Unlock() diff --git a/store/localstore/local_gc.go b/store/localstore/local_gc.go deleted file mode 100644 index eafa18b97f..0000000000 --- a/store/localstore/local_gc.go +++ /dev/null @@ -1,76 +0,0 @@ -package localstore - -import ( - "container/list" - "sync" - "time" - - "github.com/ngaut/log" - "github.com/pingcap/tidb/kv" -) - -var _ kv.GC = (*localstoreGC)(nil) - -type localstoreGC struct { - mu sync.Mutex - recentKeys map[string]struct{} - stopChan chan struct{} - ticker *time.Ticker -} - -func (gc *localstoreGC) OnSet(k kv.Key) { - gc.mu.Lock() - defer gc.mu.Unlock() - gc.recentKeys[string(k)] = struct{}{} -} - -func (gc *localstoreGC) OnGet(k kv.Key) { - gc.mu.Lock() - defer gc.mu.Unlock() - gc.recentKeys[string(k)] = struct{}{} -} - -func (gc *localstoreGC) Do(k kv.Key) { - // TODO - log.Debugf("gc key: %q", k) -} - -func (gc *localstoreGC) Start() { - go func() { - for { - select { - case <-gc.stopChan: - break - case <-gc.ticker.C: - l := list.New() - gc.mu.Lock() - for k, _ := range gc.recentKeys { - // get recent keys list - l.PushBack(k) - } - // clean recentKeys - gc.recentKeys = map[string]struct{}{} - gc.mu.Unlock() - // Do GC - for e := l.Front(); e != nil; e = e.Next() { - gc.Do([]byte(e.Value.(string))) - } - } - } - }() -} - -func (gc *localstoreGC) Stop() { - gc.stopChan <- struct{}{} - gc.ticker.Stop() - close(gc.stopChan) -} - -func newLocalGC() *localstoreGC { - return &localstoreGC{ - recentKeys: map[string]struct{}{}, - stopChan: make(chan struct{}), - // TODO hard code - ticker: time.NewTicker(time.Second * 1), - } -} diff --git a/store/localstore/local_version_provider.go b/store/localstore/local_version_provider.go index f22ffd6778..1502f46016 100644 --- a/store/localstore/local_version_provider.go +++ b/store/localstore/local_version_provider.go @@ -41,3 +41,7 @@ func (l *LocalVersionProvider) CurrentVersion() (kv.Version, error) { l.n = 0 return kv.Version{Ver: ts}, nil } + +func LocalVersionToTimestamp(ver kv.Version) uint64 { + return ver.Ver >> timePrecisionOffset +} diff --git a/store/localstore/txn.go b/store/localstore/txn.go index ddceb5b7be..04d00ecf81 100644 --- a/store/localstore/txn.go +++ b/store/localstore/txn.go @@ -95,6 +95,7 @@ func (txn *dbTxn) Inc(k kv.Key, step int64) (int64, error) { if err != nil { return 0, errors.Trace(err) } + txn.store.compactor.OnSet(k) return intVal, nil } @@ -128,7 +129,7 @@ func (txn *dbTxn) Get(k kv.Key) ([]byte, error) { if len(val) == 0 { return nil, errors.Trace(kv.ErrNotExist) } - txn.store.gc.OnGet(k) + txn.store.compactor.OnGet(k) return val, nil } @@ -145,7 +146,7 @@ func (txn *dbTxn) Set(k kv.Key, data []byte) error { if err != nil { return errors.Trace(err) } - txn.store.gc.OnSet(k) + txn.store.compactor.OnSet(k) return nil } From e8f8fdb380aa21f6377567c35879eb2f8af941e7 Mon Sep 17 00:00:00 2001 From: dongxu Date: Sat, 17 Oct 2015 15:02:43 +0800 Subject: [PATCH 27/54] mvcc-gc: fix bug --- store/localstore/compactor.go | 7 +++++++ store/localstore/compactor_test.go | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/store/localstore/compactor.go b/store/localstore/compactor.go index 6cf12dd005..88c39e54b8 100644 --- a/store/localstore/compactor.go +++ b/store/localstore/compactor.go @@ -135,6 +135,8 @@ L: func (gc *localstoreCompactor) filterExpiredKeys(keys []kv.EncodedKey) []kv.EncodedKey { var ret []kv.EncodedKey + first := true + // keys are always in descending order. for _, k := range keys { _, ver, err := MvccDecode(k) if err != nil { @@ -145,6 +147,11 @@ func (gc *localstoreCompactor) filterExpiredKeys(keys []kv.EncodedKey) []kv.Enco currentTs := time.Now().UnixNano() / int64(time.Millisecond) // Check timeout keys. if currentTs-int64(ts) >= int64(gc.policy.SafeTime) { + // Skip first version. + if first { + first = false + continue + } ret = append(ret, k) } } diff --git a/store/localstore/compactor_test.go b/store/localstore/compactor_test.go index 27cab4a379..c70f971248 100644 --- a/store/localstore/compactor_test.go +++ b/store/localstore/compactor_test.go @@ -70,5 +70,5 @@ func (s *localstoreCompactorTestSuite) TestCompactor(c *C) { time.Sleep(1 * time.Second) // Do background GC t = count(db) - c.Assert(t, Equals, 2) + c.Assert(t, Equals, 3) } From cb54395674a6205db14afc905d8819339ec6c155 Mon Sep 17 00:00:00 2001 From: dongxu Date: Sat, 17 Oct 2015 15:08:44 +0800 Subject: [PATCH 28/54] mvcc: add missing license header --- store/localstore/compactor_test.go | 13 +++++++++++++ store/localstore/gc_bench/main.go | 13 +++++++++++++ store/localstore/mvcc.go | 13 +++++++++++++ store/localstore/mvcc_test.go | 13 +++++++++++++ 4 files changed, 52 insertions(+) diff --git a/store/localstore/compactor_test.go b/store/localstore/compactor_test.go index c70f971248..051160d4e3 100644 --- a/store/localstore/compactor_test.go +++ b/store/localstore/compactor_test.go @@ -1,3 +1,16 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + package localstore import ( diff --git a/store/localstore/gc_bench/main.go b/store/localstore/gc_bench/main.go index 893b08c1ff..fe8adc31a9 100644 --- a/store/localstore/gc_bench/main.go +++ b/store/localstore/gc_bench/main.go @@ -1,3 +1,16 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + package main import ( diff --git a/store/localstore/mvcc.go b/store/localstore/mvcc.go index 9d475c4fad..f61296eee2 100644 --- a/store/localstore/mvcc.go +++ b/store/localstore/mvcc.go @@ -1,3 +1,16 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + package localstore import ( diff --git a/store/localstore/mvcc_test.go b/store/localstore/mvcc_test.go index e0add9217d..db187e1039 100644 --- a/store/localstore/mvcc_test.go +++ b/store/localstore/mvcc_test.go @@ -1,3 +1,16 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + package localstore import ( From abc90819677b6b3a1695e71df925bc3ca7c9a56f Mon Sep 17 00:00:00 2001 From: dongxu Date: Sat, 17 Oct 2015 15:15:17 +0800 Subject: [PATCH 29/54] mvcc-gc: address comments --- kv/compactor.go | 2 +- kv/kv.go | 1 - store/localstore/compactor.go | 22 ++++---- store/localstore/compactor_test.go | 2 +- store/localstore/gc_bench/main.go | 87 ------------------------------ store/localstore/kv.go | 13 ----- store/localstore/txn.go | 7 ++- 7 files changed, 19 insertions(+), 115 deletions(-) delete mode 100644 store/localstore/gc_bench/main.go diff --git a/kv/compactor.go b/kv/compactor.go index b46e56237c..702f71efa1 100644 --- a/kv/compactor.go +++ b/kv/compactor.go @@ -16,7 +16,7 @@ package kv import "time" type CompactorPolicy struct { - SafeTime int + SafePoint int TriggerInterval time.Duration BatchDeleteSize int } diff --git a/kv/kv.go b/kv/kv.go index 01223b61b2..b93e9dc3a2 100644 --- a/kv/kv.go +++ b/kv/kv.go @@ -164,7 +164,6 @@ type Storage interface { Close() error // Storage's unique ID UUID() string - DumpRaw() } // FnKeyCmp is the function for iterator the keys diff --git a/store/localstore/compactor.go b/store/localstore/compactor.go index 88c39e54b8..0e7a12b7ab 100644 --- a/store/localstore/compactor.go +++ b/store/localstore/compactor.go @@ -31,7 +31,7 @@ const ( ) var localCompactorDefaultPolicy = kv.CompactorPolicy{ - SafeTime: 20 * 1000, // in ms + SafePoint: 20 * 1000, // in ms TriggerInterval: 1 * time.Second, BatchDeleteSize: 100, } @@ -39,8 +39,8 @@ var localCompactorDefaultPolicy = kv.CompactorPolicy{ type localstoreCompactor struct { mu sync.Mutex recentKeys map[string]struct{} - stopChan chan struct{} - delChan chan kv.EncodedKey + stopCh chan struct{} + delCh chan kv.EncodedKey ticker *time.Ticker db engine.DB policy kv.CompactorPolicy @@ -87,9 +87,9 @@ func (gc *localstoreCompactor) deleteWorker() { L: for { select { - case <-gc.stopChan: + case <-gc.stopCh: break L - case key := <-gc.delChan: + case key := <-gc.delCh: { cnt++ batch.Delete(key) @@ -111,7 +111,7 @@ func (gc *localstoreCompactor) checkExpiredKeysWorker() { L: for { select { - case <-gc.stopChan: + case <-gc.stopCh: break L case <-gc.ticker.C: log.Info("GC trigger") @@ -146,7 +146,7 @@ func (gc *localstoreCompactor) filterExpiredKeys(keys []kv.EncodedKey) []kv.Enco ts := LocalVersionToTimestamp(ver) currentTs := time.Now().UnixNano() / int64(time.Millisecond) // Check timeout keys. - if currentTs-int64(ts) >= int64(gc.policy.SafeTime) { + if currentTs-int64(ts) >= int64(gc.policy.SafePoint) { // Skip first version. if first { first = false @@ -166,7 +166,7 @@ func (gc *localstoreCompactor) Compact(ctx interface{}, k kv.Key) error { for _, key := range gc.filterExpiredKeys(keys) { // Send timeout key to deleteWorker. log.Info("GC send key to deleteWorker", key) - gc.delChan <- key + gc.delCh <- key } return nil } @@ -181,14 +181,14 @@ func (gc *localstoreCompactor) Start() { func (gc *localstoreCompactor) Stop() { gc.ticker.Stop() - close(gc.stopChan) + close(gc.stopCh) } func newLocalCompactor(policy kv.CompactorPolicy, db engine.DB) *localstoreCompactor { return &localstoreCompactor{ recentKeys: make(map[string]struct{}), - stopChan: make(chan struct{}), - delChan: make(chan kv.EncodedKey, 100), + stopCh: make(chan struct{}), + delCh: make(chan kv.EncodedKey, 100), ticker: time.NewTicker(policy.TriggerInterval), policy: policy, db: db, diff --git a/store/localstore/compactor_test.go b/store/localstore/compactor_test.go index 051160d4e3..8b5d190547 100644 --- a/store/localstore/compactor_test.go +++ b/store/localstore/compactor_test.go @@ -44,7 +44,7 @@ func (s *localstoreCompactorTestSuite) TestCompactor(c *C) { store.(*dbStore).compactor.Stop() policy := kv.CompactorPolicy{ - SafeTime: 500, + SafePoint: 500, BatchDeleteSize: 1, TriggerInterval: 100 * time.Millisecond, } diff --git a/store/localstore/gc_bench/main.go b/store/localstore/gc_bench/main.go deleted file mode 100644 index fe8adc31a9..0000000000 --- a/store/localstore/gc_bench/main.go +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "flag" - "fmt" - "time" - - "github.com/ngaut/log" - "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/store/localstore" - "github.com/pingcap/tidb/store/localstore/goleveldb" -) - -var ( - store kv.Storage - logLevel = flag.String("L", "error", "Log level") -) - -// init memory store -func init() { - path := fmt.Sprintf("memory://%d", time.Now().UnixNano()) - d := localstore.Driver{ - goleveldb.MemoryDriver{}, - } - var err error - store, err = d.Open(path) - if err != nil { - panic(err) - } -} - -func dump() { - startTs := time.Now() - tx, _ := store.Begin() - it, err := tx.Seek([]byte{0}, nil) - if err != nil { - log.Error(err) - } - cnt := 0 - for it.Valid() { - log.Info(it.Key(), it.Value()) - it, _ = it.Next(nil) - cnt++ - } - tx.Commit() - elapse := time.Since(startTs) - fmt.Println(cnt, elapse) -} - -func renew() { - tx, _ := store.Begin() - for i := 0; i < 10000; i++ { - key := fmt.Sprintf("record-%d", i) - val := fmt.Sprintf("%d", time.Now().Unix()) - tx.Set([]byte(key), []byte(val)) - } - tx.Commit() -} - -func main() { - flag.Parse() - log.SetLevelByString(*logLevel) - - for i := 0; i < 100; i++ { - fmt.Printf("\n====== Round %d =====\n", i) - if i%2 == 0 { - renew() - dump() - store.DumpRaw() - } - fmt.Println("====================") - time.Sleep(2 * time.Second) - } -} diff --git a/store/localstore/kv.go b/store/localstore/kv.go index 4f34a979db..9e45a2d873 100644 --- a/store/localstore/kv.go +++ b/store/localstore/kv.go @@ -14,7 +14,6 @@ package localstore import ( - "fmt" "sync" "time" @@ -196,18 +195,6 @@ func (s *dbStore) tryConditionLockKey(tid uint64, key string, snapshotVal []byte return nil } -func (s *dbStore) DumpRaw() { - startTs := time.Now() - it, _ := s.db.Seek([]byte{0}) - defer it.Release() - cnt := 0 - for it.Next() { - cnt++ - } - elapse := time.Since(startTs) - fmt.Println(cnt, elapse) -} - func (s *dbStore) unLockKeys(keys ...string) error { s.mu.Lock() defer s.mu.Unlock() diff --git a/store/localstore/txn.go b/store/localstore/txn.go index 27a254dff0..334ab75360 100644 --- a/store/localstore/txn.go +++ b/store/localstore/txn.go @@ -174,7 +174,12 @@ func (txn *dbTxn) Seek(k kv.Key, fnKeyCmp func(kv.Key) bool) (kv.Iterator, error func (txn *dbTxn) Delete(k kv.Key) error { log.Debugf("delete key:%q, txn:%d", k, txn.tid) k = kv.EncodeKey(k) - return txn.UnionStore.Delete(k) + err := txn.UnionStore.Delete(k) + if err != nil { + return errors.Trace(err) + } + txn.store.compactor.OnDelete(k) + return nil } func (txn *dbTxn) each(f func(iterator.Iterator) error) error { From a70e4aac999a583832b193bc4c26bd206521eae4 Mon Sep 17 00:00:00 2001 From: dongxu Date: Sat, 17 Oct 2015 16:33:36 +0800 Subject: [PATCH 30/54] mvcc-gc: address review comments --- kv/compactor.go | 15 +++++++++++++-- store/localstore/compactor.go | 20 +++++++++----------- store/localstore/compactor_test.go | 2 +- store/localstore/local_version_provider.go | 2 +- 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/kv/compactor.go b/kv/compactor.go index 702f71efa1..47b48d8123 100644 --- a/kv/compactor.go +++ b/kv/compactor.go @@ -15,15 +15,26 @@ package kv import "time" +// CompactorPolicy defines gc policy of MVCC storage. type CompactorPolicy struct { - SafePoint int + // SafePoint specifies + SafePoint int + // TriggerInterval specifies how often should the compactor + // scans outdated data. TriggerInterval time.Duration - BatchDeleteSize int + // BatchDeleteCnt specifies the batch size for + // delete outdated data transaction. + BatchDeleteCnt int } +// Compactor compacts MVCC storage. type Compactor interface { + // OnGet is the hook point on Txn.Get. OnGet(k Key) + // OnSet is the hook point on Txn.Set OnSet(k Key) + // OnDelete is the hook point on Txn.Delete OnDelete(k Key) + // Compact is the function removes the given key. Compact(ctx interface{}, k Key) error } diff --git a/store/localstore/compactor.go b/store/localstore/compactor.go index 0e7a12b7ab..96e983b109 100644 --- a/store/localstore/compactor.go +++ b/store/localstore/compactor.go @@ -27,13 +27,13 @@ import ( var _ kv.Compactor = (*localstoreCompactor)(nil) const ( - deleteWorkerSize = 3 + deleteWorkerCnt = 3 ) var localCompactorDefaultPolicy = kv.CompactorPolicy{ SafePoint: 20 * 1000, // in ms TriggerInterval: 1 * time.Second, - BatchDeleteSize: 100, + BatchDeleteCnt: 100, } type localstoreCompactor struct { @@ -94,7 +94,7 @@ L: cnt++ batch.Delete(key) // Batch delete. - if cnt == gc.policy.BatchDeleteSize { + if cnt == gc.policy.BatchDeleteCnt { err := gc.db.Commit(batch) if err != nil { log.Error(err) @@ -120,12 +120,10 @@ L: gc.recentKeys = make(map[string]struct{}) gc.mu.Unlock() // Do Compactor - if len(m) > 0 { - for k, _ := range m { - err := gc.Compact(nil, []byte(k)) - if err != nil { - log.Error(err) - } + for k := range m { + err := gc.Compact(nil, []byte(k)) + if err != nil { + log.Error(err) } } } @@ -143,7 +141,7 @@ func (gc *localstoreCompactor) filterExpiredKeys(keys []kv.EncodedKey) []kv.Enco // Should not happen. panic(err) } - ts := LocalVersionToTimestamp(ver) + ts := localVersionToTimestamp(ver) currentTs := time.Now().UnixNano() / int64(time.Millisecond) // Check timeout keys. if currentTs-int64(ts) >= int64(gc.policy.SafePoint) { @@ -174,7 +172,7 @@ func (gc *localstoreCompactor) Compact(ctx interface{}, k kv.Key) error { func (gc *localstoreCompactor) Start() { // Start workers. go gc.checkExpiredKeysWorker() - for i := 0; i < deleteWorkerSize; i++ { + for i := 0; i < deleteWorkerCnt; i++ { go gc.deleteWorker() } } diff --git a/store/localstore/compactor_test.go b/store/localstore/compactor_test.go index 8b5d190547..a8f3db0424 100644 --- a/store/localstore/compactor_test.go +++ b/store/localstore/compactor_test.go @@ -45,7 +45,7 @@ func (s *localstoreCompactorTestSuite) TestCompactor(c *C) { policy := kv.CompactorPolicy{ SafePoint: 500, - BatchDeleteSize: 1, + BatchDeleteCnt: 1, TriggerInterval: 100 * time.Millisecond, } compactor := newLocalCompactor(policy, db) diff --git a/store/localstore/local_version_provider.go b/store/localstore/local_version_provider.go index 2b35e7360a..ad63831341 100644 --- a/store/localstore/local_version_provider.go +++ b/store/localstore/local_version_provider.go @@ -43,6 +43,6 @@ func (l *LocalVersionProvider) CurrentVersion() (kv.Version, error) { return kv.Version{Ver: ts}, nil } -func LocalVersionToTimestamp(ver kv.Version) uint64 { +func localVersionToTimestamp(ver kv.Version) uint64 { return ver.Ver >> timePrecisionOffset } From 9e091caef6859e8445c9ee310dd3016f451815dd Mon Sep 17 00:00:00 2001 From: siddontang Date: Sat, 17 Oct 2015 18:28:41 +0800 Subject: [PATCH 31/54] *: support ident refer in having. --- plan/plans/groupby.go | 70 +++------- plan/plans/having.go | 9 +- plan/plans/having_test.go | 4 +- plan/plans/row_stack.go | 2 +- rset/rsets/groupby.go | 4 +- rset/rsets/having.go | 260 +++++++++++++++++++++++++++++++------- rset/rsets/orderby.go | 4 +- stmt/stmts/select.go | 5 +- 8 files changed, 246 insertions(+), 112 deletions(-) diff --git a/plan/plans/groupby.go b/plan/plans/groupby.go index d3af0b857b..fee4929b2f 100644 --- a/plan/plans/groupby.go +++ b/plan/plans/groupby.go @@ -135,11 +135,11 @@ func (r *GroupByDefaultPlan) evalNoneAggFields(ctx context.Context, out []interf return out[index], nil } - // TODO: remove following getting later. - v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag) - if err == nil { - return v, nil - } + // // TODO: remove following getting later. + // v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag) + // if err == nil { + // return v, nil + // } // try to find in outer query return getIdentValueFromOuterQuery(ctx, name) @@ -161,57 +161,19 @@ func (r *GroupByDefaultPlan) evalNoneAggFields(ctx context.Context, out []interf func (r *GroupByDefaultPlan) evalAggFields(ctx context.Context, out []interface{}, m map[interface{}]interface{}, in []interface{}) error { - // Eval aggregate field results in ctx - for i := range r.AggFields { - if i < r.HiddenFieldOffset { - m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { - if scope == expression.IdentReferFromTable { - return in[index], nil - } else if scope == expression.IdentReferSelectList { - return out[index], nil - } - - //TODO: remove following getting later - v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag) - if err == nil { - return v, nil - } - - // try to find in outer query - return getIdentValueFromOuterQuery(ctx, name) - } - } else { - // having may contain aggregate function and we will add it to hidden field, - // and this field can retrieve the data in select list, e.g. - // select c1 as a from t having count(a) = 1 - // because all the select list data is generated before, so we can get it - // when handling hidden field. - m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { - if scope == expression.IdentReferFromTable { - return in[index], nil - } else if scope == expression.IdentReferSelectList { - return out[index], nil - } - - // TODO: remove later - v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag) - if err == nil { - return v, nil - } - - // TODO: remove later - // if we can not find in table, we will try to find in un-hidden select list - // only hidden field can use this - v, err = r.getFieldValueByName(name, out) - if err == nil { - return v, nil - } - - // try to find in outer query - return getIdentValueFromOuterQuery(ctx, name) - } + m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + if scope == expression.IdentReferFromTable { + return in[index], nil + } else if scope == expression.IdentReferSelectList { + return out[index], nil } + // try to find in outer query + return getIdentValueFromOuterQuery(ctx, name) + } + + // Eval aggregate field results in ctx + for i := range r.AggFields { // we must evaluate aggregate function only, e.g, select col1 + count(*) in (count(*)), // we cannot evaluate it directly here, because col1 + count(*) returns nil before AggDone phase, // so we don't evaluate count(*) in In expression, and will get an invalid data in AggDone phase for it. diff --git a/plan/plans/having.go b/plan/plans/having.go index 231f059eaa..62c64d3647 100644 --- a/plan/plans/having.go +++ b/plan/plans/having.go @@ -62,10 +62,11 @@ func (r *HavingPlan) Next(ctx context.Context) (row *plan.Row, err error) { if srcRow == nil || err != nil { return nil, errors.Trace(err) } - r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { - v, err0 := GetIdentValue(name, r.Src.GetFields(), srcRow.Data, field.CheckFieldFlag) - if err0 == nil { - return v, nil + r.evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + if scope == expression.IdentReferFromTable { + return srcRow.FromData[index], nil + } else if scope == expression.IdentReferSelectList { + return srcRow.Data[index], nil } // try to find in outer query diff --git a/plan/plans/having_test.go b/plan/plans/having_test.go index a3777fe9f2..21a4fc19cf 100644 --- a/plan/plans/having_test.go +++ b/plan/plans/having_test.go @@ -38,7 +38,9 @@ func (t *testHavingPlan) TestHaving(c *C) { Expr: &expression.BinaryOperation{ Op: opcode.GE, L: &expression.Ident{ - CIStr: model.NewCIStr("id"), + CIStr: model.NewCIStr("id"), + ReferScope: expression.IdentReferSelectList, + ReferIndex: 0, }, R: &expression.Value{ Val: 20, diff --git a/plan/plans/row_stack.go b/plan/plans/row_stack.go index f4482bcfca..7b10919d29 100644 --- a/plan/plans/row_stack.go +++ b/plan/plans/row_stack.go @@ -199,5 +199,5 @@ func getIdentValueFromOuterQuery(ctx context.Context, name string) (interface{}, } } - return nil, errors.Trace(err) + return nil, errors.Errorf("unknown field %s", name) } diff --git a/rset/rsets/groupby.go b/rset/rsets/groupby.go index 176eee5e93..2ac5328523 100644 --- a/rset/rsets/groupby.go +++ b/rset/rsets/groupby.go @@ -90,9 +90,7 @@ func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, } // TODO: check in out query - // TODO: return unknown field error, but now just return directly. - // Because this may reference outer query. - return i, nil + return i, errors.Errorf("Unknown column '%s' in 'group statement'", i) } func (v *groupByVisitor) VisitCall(c *expression.Call) (expression.Expression, error) { diff --git a/rset/rsets/having.go b/rset/rsets/having.go index 4a3ad02080..2105bfd0fe 100644 --- a/rset/rsets/having.go +++ b/rset/rsets/having.go @@ -31,6 +31,8 @@ type HavingRset struct { Src plan.Plan Expr expression.Expression SelectList *plans.SelectList + // Group by clause + GroupBy []expression.Expression } // CheckAggregate will check whether order by has aggregate function or not, @@ -48,64 +50,234 @@ func (r *HavingRset) CheckAggregate(selectList *plans.SelectList) error { return nil } -// CheckAndUpdateSelectList checks having fields validity and set hidden fields to selectList. -func (r *HavingRset) CheckAndUpdateSelectList(selectList *plans.SelectList, groupBy []expression.Expression, tableFields []*field.ResultField) error { - if expression.ContainAggregateFunc(r.Expr) { - expr, err := selectList.UpdateAggFields(r.Expr) - if err != nil { - return errors.Errorf("%s in 'having clause'", err.Error()) +type havingVisitor struct { + expression.BaseVisitor + selectList *plans.SelectList + + // for group by + groupBy []expression.Expression + + // true means we are visiting aggregate function arguments now. + inAggregate bool +} + +func (v *havingVisitor) visitIdentInAggregate(i *expression.Ident) (expression.Expression, error) { + // if we are visiting aggregate function arguments, the identifier first checks in from table, + // then in select list, and outer query finally. + + // find this identifier in FROM. + idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag) + if len(idx) > 0 { + i.ReferScope = expression.IdentReferFromTable + i.ReferIndex = idx[0] + return i, nil + } + + // check in select list. + index, err := checkIdent(i, v.selectList, havingClause) + if err != nil { + return i, errors.Trace(err) + } + + if index >= 0 { + // find in select list + i.ReferScope = expression.IdentReferSelectList + i.ReferIndex = index + return i, nil + } + + // TODO: check in out query + // TODO: return unknown field error, but now just return directly. + // Because this may reference outer query. + return i, nil +} + +func (v *havingVisitor) checkIdentInGroupBy(i *expression.Ident) (*expression.Ident, bool, error) { + for _, by := range v.groupBy { + e := castIdent(by) + if e == nil { + // group by must be a identifier too + continue } - r.Expr = expr - } else { - // having can only contain group by column and select list, e.g, - // `select c1 from t group by c2 having c3 > 0` is invalid, - // because c3 is not in group by and select list. - names := expression.MentionedColumns(r.Expr) - for _, name := range names { - found := false + if !field.CheckFieldsEqual(e.L, i.L) { + // not same, continue + continue + } - // check name whether in select list. - // notice that `select c1 as c2 from t group by c1, c2, c3 having c2 > c3`, - // will use t.c2 not t.c1 here. - if field.ContainFieldName(name, selectList.ResultFields, field.OrgFieldNameFlag) { - continue - } - if field.ContainFieldName(name, selectList.ResultFields, field.FieldNameFlag) { - if field.ContainFieldName(name, tableFields, field.OrgFieldNameFlag) { - selectList.CloneHiddenField(name, tableFields) - } - continue - } + // if group by references select list or having is qualified identifier, + // no other check. + if e.ReferScope == expression.IdentReferSelectList || field.IsQualifiedName(i.L) { + i.ReferScope = e.ReferScope + i.ReferIndex = e.ReferIndex + return i, true, nil + } - // check name whether in group by. - // group by must only have column name, e.g, - // `select c1 from t group by c2 having c2 > 0` is valid, - // but `select c1 from t group by c2 + 1 having c2 > 0` is invalid. - for _, by := range groupBy { - if !field.CheckFieldsEqual(name, by.String()) { - continue - } + // having is unqualified name, e.g, select * from t1, t2 group by t1.c having c. + // both t1 and t2 have column c, we must check ambiguous here. - // if name is not in table fields, it will get an unknown field error in GroupByRset, - // so no need to check return value. - selectList.CloneHiddenField(name, tableFields) + idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag) + if len(idx) > 1 { + return i, false, errors.Errorf("Column '%s' in having clause is ambiguous", i) + } - found = true - break - } + i.ReferScope = e.ReferScope + i.ReferIndex = e.ReferIndex + return i, true, nil + } - if !found { - return errors.Errorf("Unknown column '%s' in 'having clause'", name) - } + return i, false, nil +} + +func (v *havingVisitor) checkIdentInSelectList(i *expression.Ident) (*expression.Ident, bool, error) { + index, err := checkIdent(i, v.selectList, havingClause) + if err != nil { + return i, false, errors.Trace(err) + } + + if index >= 0 { + // identifier references a select field. use it directly. + // e,g. select c1 as c2 from t order by c2, here c2 references c1. + i.ReferScope = expression.IdentReferSelectList + i.ReferIndex = index + return i, true, nil + } + + // we may meet this select c1 as c2 from t having c1, so we must check origin field name too. + for index := 0; index < v.selectList.HiddenFieldOffset; index++ { + e := castIdent(v.selectList.Fields[index].Expr) + if e == nil { + // not identifier + continue + } + + if !field.CheckFieldsEqual(e.L, i.L) { + // not same, continue + continue + } + + if field.IsQualifiedName(i.L) { + // qualified name, no need check + i.ReferScope = expression.IdentReferSelectList + i.ReferIndex = index + return i, true, nil + } + + // we may meet select t1.c as a from t1, t2 having c, must check ambiguous here + idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag) + if len(idx) > 1 { + return i, false, errors.Errorf("Column '%s' in having clause is ambiguous", i) + } + + i.ReferScope = expression.IdentReferSelectList + i.ReferIndex = index + return i, true, nil + } + + return i, false, nil +} + +func (v *havingVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) { + // Having has the most complex rule for ambiguous and identifer reference check. + + // Having ambiguous rule: + // select c1 as a, c2 as a from t having a is ambiguous + // select c1 as a, c2 as a from t having a + 1 is ambiguous + // select c1 as c2, c2 from t having c2 is ambiguous + // select c1 as c2, c2 from t having c2 + 1 is ambiguous + + // The identifier in having must exist in group by, select list and outer query, so select c1 from t having c2 is wrong, + // the idenfitier will first try to find the reference in group by, then in select list and finally in outer query. + + // Having identifier reference check + // select c1 from t group by c2 having c2, c2 in having references group by c2. + // select c1 as c2 from t having c2, c2 in having references c1 in select field. + // select c1 as c2 from t group by c2 having c2, c2 in having references group by c2. + // select c1 as c2 from t having sum(c2), c2 in having sum(c2) references t.c2 in from table. + // select c1 as c2 from t having sum(c2) + c2, c2 in having sum(c2) references t.c2 in from table, but another c2 references c1 in select list. + // select c1 as c2 from t group by c2 having sum(c2) + c2, c2 in having sum(c2) references t.c2 in from table, another c2 references c2 in group by. + // select c1 as a from t having a, a references c1 in select field. + // select c1 as a from t having sum(a), a in order by sum(a) references c1 in select field. + // select c1 as c2 from t having c1, c1 in having references c1 in select list. + + // TODO: unify VisitIdent for group by and order by visitor, they are little different. + + if v.inAggregate { + return v.visitIdentInAggregate(i) + } + + var ( + err error + ok bool + ) + + // first, check in group by + i, ok, err = v.checkIdentInGroupBy(i) + if err != nil || ok { + return i, errors.Trace(err) + } + + // then in select list + i, ok, err = v.checkIdentInSelectList(i) + if err != nil || ok { + return i, errors.Trace(err) + } + + // TODO: check in out query + + return i, errors.Errorf("Unknown column '%s' in 'having clause'", i) +} + +func (v *havingVisitor) VisitPosition(p *expression.Position) (expression.Expression, error) { + n := p.N + + // we have added order by to hidden field before, now visit it here. + expr := v.selectList.Fields[n-1].Expr + e, err := expr.Accept(v) + if err != nil { + return nil, errors.Trace(err) + } + v.selectList.Fields[n-1].Expr = e + + return p, nil +} + +func (v *havingVisitor) VisitCall(c *expression.Call) (expression.Expression, error) { + isAggregate, err := expression.IsAggregateFunc(c.F) + if err != nil { + return nil, errors.Trace(err) + } + + v.inAggregate = isAggregate + + for i, e := range c.Args { + c.Args[i], err = e.Accept(v) + if err != nil { + return nil, errors.Trace(err) } } - return nil + v.inAggregate = false + + return c, nil } // Plan gets HavingPlan. func (r *HavingRset) Plan(ctx context.Context) (plan.Plan, error) { + visitor := &havingVisitor{ + selectList: r.SelectList, + groupBy: r.GroupBy, + inAggregate: false, + } + visitor.BaseVisitor.V = visitor + + e, err := r.Expr.Accept(visitor) + if err != nil { + return nil, errors.Trace(err) + } + + r.Expr = e + return &plans.HavingPlan{Src: r.Src, Expr: r.Expr, SelectList: r.SelectList}, nil } diff --git a/rset/rsets/orderby.go b/rset/rsets/orderby.go index c16daf5ba4..f71634373b 100644 --- a/rset/rsets/orderby.go +++ b/rset/rsets/orderby.go @@ -150,9 +150,7 @@ func (v *orderByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, } // TODO: check in out query - // TODO: return unknown field error, but now just return directly. - // Because this may reference outer query. - return i, nil + return i, errors.Errorf("Unknown column '%s' in 'order clause'", i) } func (v *orderByVisitor) VisitPosition(p *expression.Position) (expression.Expression, error) { diff --git a/stmt/stmts/select.go b/stmt/stmts/select.go index 1b0b48b16d..a7ecfc75ce 100644 --- a/stmt/stmts/select.go +++ b/stmt/stmts/select.go @@ -175,7 +175,7 @@ func (s *SelectStmt) Plan(ctx context.Context) (plan.Plan, error) { if s.Having != nil { // `having` may contain aggregate functions, and we will add this to hidden fields. - if err = s.Having.CheckAndUpdateSelectList(selectList, groupBy, r.GetFields()); err != nil { + if err = s.Having.CheckAggregate(selectList); err != nil { return nil, errors.Trace(err) } } @@ -209,7 +209,8 @@ func (s *SelectStmt) Plan(ctx context.Context) (plan.Plan, error) { if r, err = (&rsets.HavingRset{ Src: r, Expr: s.Expr, - SelectList: selectList}).Plan(ctx); err != nil { + SelectList: selectList, + GroupBy: groupBy}).Plan(ctx); err != nil { return nil, err } } From a260bbb6c83f52cf50c1caa3ebae4ccaa390df99 Mon Sep 17 00:00:00 2001 From: siddontang Date: Sat, 17 Oct 2015 22:57:51 +0800 Subject: [PATCH 32/54] *: add more clause type --- plan/plans/select_list.go | 19 +------------------ rset/rsets/fields.go | 2 +- rset/rsets/helper.go | 19 ++++++++++++++++--- rset/rsets/where.go | 2 +- 4 files changed, 19 insertions(+), 23 deletions(-) diff --git a/plan/plans/select_list.go b/plan/plans/select_list.go index bacc7b729d..66704acdfb 100644 --- a/plan/plans/select_list.go +++ b/plan/plans/select_list.go @@ -124,24 +124,6 @@ func (s *SelectList) UpdateAggFields(expr expression.Expression) (expression.Exp return &expression.Position{N: index + 1, Name: name}, nil } -// CloneHiddenField checks and clones field and result field from table fields, -// and adds them to hidden field of select list. -func (s *SelectList) CloneHiddenField(name string, tableFields []*field.ResultField) bool { - // Check and add hidden field. - if field.ContainFieldName(name, tableFields, field.CheckFieldFlag) { - resultField, _ := field.CloneFieldByName(name, tableFields, field.CheckFieldFlag) - f := &field.Field{ - Expr: &expression.Ident{ - CIStr: resultField.ColumnInfo.Name, - }, - } - s.AddField(f, resultField) - return true - } - - return false -} - // CheckReferAmbiguous checks whether an identifier reference is ambiguous or not in select list. // e,g, "select c1 as a, c2 as a from t group by a" is ambiguous, // but "select c1 as a, c1 as a from t group by a" is not. @@ -235,6 +217,7 @@ func ResolveSelectList(selectFields []*field.Field, srcFields []*field.ResultFie continue } + // TODO: use fromIdentVisitor to cleanup. var result *field.ResultField for _, name := range names { idx := field.GetResultFieldIndex(name, srcFields, field.DefaultFieldFlag) diff --git a/rset/rsets/fields.go b/rset/rsets/fields.go index 4df4726463..ffe6176409 100644 --- a/rset/rsets/fields.go +++ b/rset/rsets/fields.go @@ -40,7 +40,7 @@ type SelectFieldsRset struct { } func updateSelectFieldsRefer(selectList *plans.SelectList) error { - visitor := newFromIdentVisitor(selectList.FromFields) + visitor := newFromIdentVisitor(selectList.FromFields, fieldListClause) // we only fix un-hidden fields here, for hidden fields, it should be // handled in their own clause, in order by or having. diff --git a/rset/rsets/helper.go b/rset/rsets/helper.go index abea6bb44f..5f54d9f435 100644 --- a/rset/rsets/helper.go +++ b/rset/rsets/helper.go @@ -50,13 +50,22 @@ type clauseType int const ( noneClause clauseType = iota + onClause + whereClause groupByClause - orderByClause + fieldListClause havingClause + orderByClause ) func (clause clauseType) String() string { switch clause { + case onClause: + return "on clause" + case fieldListClause: + return "field list" + case whereClause: + return "where clause" case groupByClause: return "group statement" case orderByClause: @@ -138,24 +147,28 @@ func checkIdent(i *expression.Ident, selectList *plans.SelectList, clause clause type fromIdentVisitor struct { expression.BaseVisitor fromFields []*field.ResultField + clause clauseType } func (v *fromIdentVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) { idx := field.GetResultFieldIndex(i.L, v.fromFields, field.DefaultFieldFlag) - if len(idx) > 0 { + if len(idx) == 1 { i.ReferScope = expression.IdentReferFromTable i.ReferIndex = idx[0] return i, nil + } else if len(idx) > 1 { + return nil, errors.Errorf("Column '%s' in %s is ambiguous", i, v.clause) } // TODO: check in outer query return i, nil } -func newFromIdentVisitor(fromFields []*field.ResultField) *fromIdentVisitor { +func newFromIdentVisitor(fromFields []*field.ResultField, clause clauseType) *fromIdentVisitor { visitor := &fromIdentVisitor{} visitor.BaseVisitor.V = visitor visitor.fromFields = fromFields + visitor.clause = clause return visitor } diff --git a/rset/rsets/where.go b/rset/rsets/where.go index f32313ac8f..3a583fe010 100644 --- a/rset/rsets/where.go +++ b/rset/rsets/where.go @@ -180,7 +180,7 @@ func (r *WhereRset) planStatic(ctx context.Context, e expression.Expression) (pl } func (r *WhereRset) updateWhereFieldsRefer() error { - visitor := newFromIdentVisitor(r.Src.GetFields()) + visitor := newFromIdentVisitor(r.Src.GetFields(), whereClause) e, err := r.Expr.Accept(visitor) if err != nil { From 655ed973be9428534fc04379b49bd0da21ba1856 Mon Sep 17 00:00:00 2001 From: siddontang Date: Sat, 17 Oct 2015 22:58:07 +0800 Subject: [PATCH 33/54] *: update having and add test --- rset/rsets/having.go | 19 ++++++++++++++++--- tidb_test.go | 32 ++++++++++++++++++++++++++++---- 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/rset/rsets/having.go b/rset/rsets/having.go index 2105bfd0fe..6119227a3f 100644 --- a/rset/rsets/having.go +++ b/rset/rsets/having.go @@ -143,6 +143,8 @@ func (v *havingVisitor) checkIdentInSelectList(i *expression.Ident) (*expression return i, true, nil } + lastIndex := -1 + var lastFieldName string // we may meet this select c1 as c2 from t having c1, so we must check origin field name too. for index := 0; index < v.selectList.HiddenFieldOffset; index++ { e := castIdent(v.selectList.Fields[index].Expr) @@ -163,9 +165,14 @@ func (v *havingVisitor) checkIdentInSelectList(i *expression.Ident) (*expression return i, true, nil } - // we may meet select t1.c as a from t1, t2 having c, must check ambiguous here - idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag) - if len(idx) > 1 { + if lastIndex == -1 { + lastIndex = index + lastFieldName = e.L + continue + } + + // we may meet select t1.c as a, t2.c as b from t1, t2 having c, must check ambiguous here + if !field.CheckFieldsEqual(lastFieldName, e.L) { return i, false, errors.Errorf("Column '%s' in having clause is ambiguous", i) } @@ -174,6 +181,12 @@ func (v *havingVisitor) checkIdentInSelectList(i *expression.Ident) (*expression return i, true, nil } + if lastIndex != -1 { + i.ReferScope = expression.IdentReferSelectList + i.ReferIndex = lastIndex + return i, true, nil + } + return i, false, nil } diff --git a/tidb_test.go b/tidb_test.go index 120a64e719..31391ce0a0 100644 --- a/tidb_test.go +++ b/tidb_test.go @@ -1177,11 +1177,35 @@ func (s *testSessionSuite) TestHaving(c *C) { store := newStore(c, s.dbName) se := newSession(c, store, s.dbName) mustExecSQL(c, se, "drop table if exists t") - mustExecSQL(c, se, "create table t (c1 int, c2 int)") - mustExecSQL(c, se, "insert into t values (1,2), (2, 1)") + mustExecSQL(c, se, "create table t (c1 int, c2 int, c3 int)") + mustExecSQL(c, se, "insert into t values (1,2,3), (2, 3, 1), (3, 1, 2)") - mustExecMatch(c, se, "select sum(c1) from t group by c1 having sum(c1)", [][]interface{}{{1}, {2}}) - mustExecMatch(c, se, "select sum(c1) - 1 from t group by c1 having sum(c1) - 1", [][]interface{}{{1}}) + mustExecMatch(c, se, "select c1 as c2, c3 from t having c2 = 2", [][]interface{}{{2, 1}}) + mustExecMatch(c, se, "select c1 as c2, c3 from t group by c2 having c2 = 2;", [][]interface{}{{1, 3}}) + mustExecMatch(c, se, "select c1 as c2, c3 from t group by c2 having sum(c2) = 2;", [][]interface{}{{1, 3}}) + mustExecMatch(c, se, "select c1 as c2, c3 from t group by c3 having sum(c2) = 2;", [][]interface{}{{1, 3}}) + mustExecMatch(c, se, "select c1 as c2, c3 from t group by c3 having sum(0) + c2 = 2;", [][]interface{}{{2, 1}}) + mustExecMatch(c, se, "select c1 as a from t having c1 = 1;", [][]interface{}{{1}}) + mustExecMatch(c, se, "select t.c1 from t having c1 = 1;", [][]interface{}{{1}}) + mustExecMatch(c, se, "select a.c1 from t as a having c1 = 1;", [][]interface{}{{1}}) + mustExecMatch(c, se, "select c1 as a from t group by c3 having sum(a) = 1;", [][]interface{}{{1}}) + mustExecMatch(c, se, "select c1 as a from t group by c3 having sum(a) + a = 2;", [][]interface{}{{1}}) + mustExecMatch(c, se, "select a.c1 as c, a.c1 as d from t as a, t as b having c1 = 1 limit 1;", [][]interface{}{{1, 1}}) + + mustExecMatch(c, se, "select sum(c1) from t group by c1 having sum(c1)", [][]interface{}{{1}, {2}, {3}}) + mustExecMatch(c, se, "select sum(c1) - 1 from t group by c1 having sum(c1) - 1", [][]interface{}{{1}, {2}}) + + mustExecFailed(c, se, "select c1 from t having c2") + mustExecFailed(c, se, "select c1 from t having c2 + 1") + mustExecFailed(c, se, "select c1 from t group by c2 + 1 having c2") + mustExecFailed(c, se, "select c1 from t group by c2 + 1 having c2 + 1") + mustExecFailed(c, se, "select c1 as c2, c2 from t having c2") + mustExecFailed(c, se, "select c1 as c2, c2 from t having c2 + 1") + mustExecFailed(c, se, "select c1 as a, c2 as a from t having a") + mustExecFailed(c, se, "select c1 as a, c2 as a from t having a + 1") + mustExecFailed(c, se, "select c1 + 1 from t having c1") + mustExecFailed(c, se, "select c1 + 1 from t having c1 + 1") + mustExecFailed(c, se, "select a.c1 as c, b.c1 as d from t as a, t as b having c1") } func newSession(c *C, store kv.Storage, dbName string) Session { From c464ae765aa2a6138a0feb711725097c0f135f33 Mon Sep 17 00:00:00 2001 From: siddontang Date: Sun, 18 Oct 2015 08:55:31 +0800 Subject: [PATCH 34/54] *: on support ident refer --- plan/plans/join.go | 38 ++++++++++++++++++++++---------------- rset/rsets/helper.go | 5 +++++ rset/rsets/join.go | 11 +++++++++++ 3 files changed, 38 insertions(+), 16 deletions(-) diff --git a/plan/plans/join.go b/plan/plans/join.go index cd04caa6cd..0ac18e75a6 100644 --- a/plan/plans/join.go +++ b/plan/plans/join.go @@ -155,6 +155,7 @@ func (r *JoinPlan) Next(ctx context.Context) (row *plan.Row, err error) { } func (r *JoinPlan) nextLeftJoin(ctx context.Context) (row *plan.Row, err error) { + visitor := newJoinIdentVisitor(r.Left.GetFields(), 0) for { if r.cursor < len(r.matchedRows) { row = r.matchedRows[r.cursor] @@ -167,7 +168,7 @@ func (r *JoinPlan) nextLeftJoin(ctx context.Context) (row *plan.Row, err error) return nil, errors.Trace(err) } tempExpr := r.On.Clone() - visitor := NewIdentEvalVisitor(r.Left.GetFields(), leftRow.Data) + visitor.row = leftRow.Data _, err = tempExpr.Accept(visitor) if err != nil { return nil, errors.Trace(err) @@ -189,6 +190,7 @@ func (r *JoinPlan) nextLeftJoin(ctx context.Context) (row *plan.Row, err error) } func (r *JoinPlan) nextRightJoin(ctx context.Context) (row *plan.Row, err error) { + visitor := newJoinIdentVisitor(r.Right.GetFields(), len(r.Left.GetFields())) for { if r.cursor < len(r.matchedRows) { row = r.matchedRows[r.cursor] @@ -202,7 +204,7 @@ func (r *JoinPlan) nextRightJoin(ctx context.Context) (row *plan.Row, err error) } tempExpr := r.On.Clone() - visitor := NewIdentEvalVisitor(r.Right.GetFields(), rightRow.Data) + visitor.row = rightRow.Data _, err = tempExpr.Accept(visitor) if err != nil { return nil, errors.Trace(err) @@ -251,8 +253,8 @@ func (r *JoinPlan) findMatchedRows(ctx context.Context, row *plan.Row, p plan.Pl } else { joined = append(append(joined, row.Data...), cmpRow.Data...) } - r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { - return GetIdentValue(name, r.Fields, joined, field.DefaultFieldFlag) + r.evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + return joined[index], nil } var b bool b, err = expression.EvalBoolExpr(ctx, r.On, r.evalArgs) @@ -280,6 +282,7 @@ func (r *JoinPlan) findMatchedRows(ctx context.Context, row *plan.Row, p plan.Pl } func (r *JoinPlan) nextCrossJoin(ctx context.Context) (row *plan.Row, err error) { + visitor := newJoinIdentVisitor(r.Left.GetFields(), 0) for { if r.curRow == nil { r.curRow, err = r.Left.Next(ctx) @@ -292,7 +295,7 @@ func (r *JoinPlan) nextCrossJoin(ctx context.Context) (row *plan.Row, err error) if r.On != nil { tempExpr := r.On.Clone() - visitor := NewIdentEvalVisitor(r.Left.GetFields(), r.curRow.Data) + visitor.row = r.curRow.Data _, err = tempExpr.Accept(visitor) if err != nil { return nil, errors.Trace(err) @@ -323,8 +326,8 @@ func (r *JoinPlan) nextCrossJoin(ctx context.Context) (row *plan.Row, err error) joinedRow = append(append(joinedRow, r.curRow.Data...), rightRow.Data...) if r.On != nil { - r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { - return GetIdentValue(name, r.Fields, joinedRow, field.DefaultFieldFlag) + r.evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + return joinedRow[index], nil } b, err := expression.EvalBoolExpr(ctx, r.On, r.evalArgs) @@ -356,32 +359,35 @@ func (r *JoinPlan) Close() error { return r.Left.Close() } -// IdentEvalVisitor converts Ident expression to value expression. -type IdentEvalVisitor struct { +// joinIdentVisitor converts Ident expression to value expression in ON expression. +type joinIdentVisitor struct { expression.BaseVisitor fields []*field.ResultField row []interface{} + offset int } -// NewIdentEvalVisitor creates a new IdentEvalVisitor. -func NewIdentEvalVisitor(fields []*field.ResultField, row []interface{}) *IdentEvalVisitor { - iev := &IdentEvalVisitor{fields: fields, row: row} +// newJoinIdentVisitor creates a new joinIdentVisitor. +func newJoinIdentVisitor(fields []*field.ResultField, offset int) *joinIdentVisitor { + iev := &joinIdentVisitor{fields: fields, offset: offset} iev.BaseVisitor.V = iev return iev } // VisitIdent implements Visitor interface. -func (iev *IdentEvalVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) { - v, err := GetIdentValue(i.L, iev.fields, iev.row, field.CheckFieldFlag) - if err != nil { +func (iev *joinIdentVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) { + // the row here may be just left part or right part, but identifier may reference another part, + // so here we must check its reference index validation. + if i.ReferIndex < iev.offset || i.ReferIndex >= iev.offset+len(iev.row) { return i, nil } + v := iev.row[i.ReferIndex-iev.offset] return expression.Value{Val: v}, nil } // VisitBinaryOperation swaps the right side identifier to left side if left side expression is static. // So it can be used in index plan. -func (iev *IdentEvalVisitor) VisitBinaryOperation(binop *expression.BinaryOperation) (expression.Expression, error) { +func (iev *joinIdentVisitor) VisitBinaryOperation(binop *expression.BinaryOperation) (expression.Expression, error) { var err error binop.L, err = binop.L.Accept(iev) if err != nil { diff --git a/rset/rsets/helper.go b/rset/rsets/helper.go index 5f54d9f435..9cc3d465e7 100644 --- a/rset/rsets/helper.go +++ b/rset/rsets/helper.go @@ -160,6 +160,11 @@ func (v *fromIdentVisitor) VisitIdent(i *expression.Ident) (expression.Expressio return nil, errors.Errorf("Column '%s' in %s is ambiguous", i, v.clause) } + if v.clause == onClause { + // on clause can't check outer query. + return nil, errors.Errorf("Unknown column '%s' in '%s'", i, v.clause) + } + // TODO: check in outer query return i, nil } diff --git a/rset/rsets/join.go b/rset/rsets/join.go index e358a27dc7..97e7c8e583 100644 --- a/rset/rsets/join.go +++ b/rset/rsets/join.go @@ -141,6 +141,17 @@ func (r *JoinRset) buildJoinPlan(ctx context.Context, p *plans.JoinPlan, s *Join p.Fields = append(p.Fields, rightFields...) } + if p.On != nil { + visitor := newFromIdentVisitor(p.Fields, onClause) + + e, err := p.On.Accept(visitor) + if err != nil { + return errors.Trace(err) + } + + p.On = e + } + return nil } From ce3d56b47c11a898b605bf1c44dc57ca434e77c4 Mon Sep 17 00:00:00 2001 From: siddontang Date: Sun, 18 Oct 2015 08:58:20 +0800 Subject: [PATCH 35/54] *: add test and update comment. --- rset/rsets/helper.go | 1 + tidb_test.go | 1 + 2 files changed, 2 insertions(+) diff --git a/rset/rsets/helper.go b/rset/rsets/helper.go index 9cc3d465e7..90208228e1 100644 --- a/rset/rsets/helper.go +++ b/rset/rsets/helper.go @@ -46,6 +46,7 @@ func castIdent(e expression.Expression) *expression.Ident { return i } +// TODO: export clause type and move to plan? type clauseType int const ( diff --git a/tidb_test.go b/tidb_test.go index 31391ce0a0..a4582a6b31 100644 --- a/tidb_test.go +++ b/tidb_test.go @@ -882,6 +882,7 @@ func (s *testSessionSuite) TestSelect(c *C) { c.Assert(err, IsNil) matches(c, rows, [][]interface{}{{1, nil, nil}, {2, 2, nil}}) + mustExecFailed(c, se, "select * from t1 left join t2 on t1.c1 = t3.c3 left join on t3 on t1.c1 = t2.c2") } func (s *testSessionSuite) TestSubQuery(c *C) { From 9ef6a5357a58f489252bc7bd8193c7cbfd9c2165 Mon Sep 17 00:00:00 2001 From: siddontang Date: Sun, 18 Oct 2015 09:52:43 +0800 Subject: [PATCH 36/54] *: Address comment. --- expression/ident_test.go | 2 +- plan/plans/groupby.go | 6 ------ plan/plans/join.go | 11 +++++------ 3 files changed, 6 insertions(+), 13 deletions(-) diff --git a/expression/ident_test.go b/expression/ident_test.go index 08463d7eee..6d3dfe3719 100644 --- a/expression/ident_test.go +++ b/expression/ident_test.go @@ -57,7 +57,7 @@ func (s *testIdentSuite) TestIdent(c *C) { c.Assert(v, Equals, 1) delete(m, ExprEvalIdentFunc) - e.ReferScope = 1 + e.ReferScope = IdentReferSelectList e.ReferIndex = 1 m[ExprEvalIdentReferFunc] = func(string, int, int) (interface{}, error) { return 2, nil diff --git a/plan/plans/groupby.go b/plan/plans/groupby.go index fee4929b2f..14b6c422ed 100644 --- a/plan/plans/groupby.go +++ b/plan/plans/groupby.go @@ -135,12 +135,6 @@ func (r *GroupByDefaultPlan) evalNoneAggFields(ctx context.Context, out []interf return out[index], nil } - // // TODO: remove following getting later. - // v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag) - // if err == nil { - // return v, nil - // } - // try to find in outer query return getIdentValueFromOuterQuery(ctx, name) } diff --git a/plan/plans/join.go b/plan/plans/join.go index 0ac18e75a6..bcea22bad0 100644 --- a/plan/plans/join.go +++ b/plan/plans/join.go @@ -155,7 +155,7 @@ func (r *JoinPlan) Next(ctx context.Context) (row *plan.Row, err error) { } func (r *JoinPlan) nextLeftJoin(ctx context.Context) (row *plan.Row, err error) { - visitor := newJoinIdentVisitor(r.Left.GetFields(), 0) + visitor := newJoinIdentVisitor(0) for { if r.cursor < len(r.matchedRows) { row = r.matchedRows[r.cursor] @@ -190,7 +190,7 @@ func (r *JoinPlan) nextLeftJoin(ctx context.Context) (row *plan.Row, err error) } func (r *JoinPlan) nextRightJoin(ctx context.Context) (row *plan.Row, err error) { - visitor := newJoinIdentVisitor(r.Right.GetFields(), len(r.Left.GetFields())) + visitor := newJoinIdentVisitor(len(r.Left.GetFields())) for { if r.cursor < len(r.matchedRows) { row = r.matchedRows[r.cursor] @@ -282,7 +282,7 @@ func (r *JoinPlan) findMatchedRows(ctx context.Context, row *plan.Row, p plan.Pl } func (r *JoinPlan) nextCrossJoin(ctx context.Context) (row *plan.Row, err error) { - visitor := newJoinIdentVisitor(r.Left.GetFields(), 0) + visitor := newJoinIdentVisitor(0) for { if r.curRow == nil { r.curRow, err = r.Left.Next(ctx) @@ -362,14 +362,13 @@ func (r *JoinPlan) Close() error { // joinIdentVisitor converts Ident expression to value expression in ON expression. type joinIdentVisitor struct { expression.BaseVisitor - fields []*field.ResultField row []interface{} offset int } // newJoinIdentVisitor creates a new joinIdentVisitor. -func newJoinIdentVisitor(fields []*field.ResultField, offset int) *joinIdentVisitor { - iev := &joinIdentVisitor{fields: fields, offset: offset} +func newJoinIdentVisitor(offset int) *joinIdentVisitor { + iev := &joinIdentVisitor{offset: offset} iev.BaseVisitor.V = iev return iev } From 16836344a5f51c8cc8c254f2aae6f76a8a89536c Mon Sep 17 00:00:00 2001 From: siddontang Date: Sun, 18 Oct 2015 14:33:50 +0800 Subject: [PATCH 37/54] *: fix aggregate arg check panic. --- expression/helper.go | 4 ++-- tidb_test.go | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/expression/helper.go b/expression/helper.go index 0f97ddbcdc..e0d8c5b639 100644 --- a/expression/helper.go +++ b/expression/helper.go @@ -152,7 +152,6 @@ func (v *mentionedAggregateFuncsVisitor) VisitCall(c *Call) (Expression, error) v.exprs = append(v.exprs, c) } - // aggregate function can't use aggregate function as the arg. n := len(v.exprs) for _, e := range c.Args { _, err := e.Accept(v) @@ -161,7 +160,8 @@ func (v *mentionedAggregateFuncsVisitor) VisitCall(c *Call) (Expression, error) } } - if len(v.exprs) != n { + if isAggregate && len(v.exprs) != n { + // aggregate function can't use aggregate function as the arg. // here means we have aggregate function in arg. return nil, errors.Errorf("Invalid use of group function") } diff --git a/tidb_test.go b/tidb_test.go index a4582a6b31..43430e4890 100644 --- a/tidb_test.go +++ b/tidb_test.go @@ -1141,6 +1141,7 @@ func (s *testSessionSuite) TestGroupBy(c *C) { mustExecSQL(c, se, "drop table if exists t") mustExecSQL(c, se, "create table t (c1 int, c2 int)") mustExecSQL(c, se, "insert into t values (1,1), (2,2), (1,2), (1,3)") + mustExecMatch(c, se, "select nullif (count(*), 2);", [][]interface{}{{1}}) mustExecMatch(c, se, "select c1 as c2, c2 from t group by c2 + 1", [][]interface{}{{1, 1}, {2, 2}, {1, 3}}) mustExecMatch(c, se, "select c1 as c2, count(c1) from t group by c2", [][]interface{}{{1, 1}, {2, 2}, {1, 1}}) From 7976510b791ccf9a4bc4e0c993b9d1b93a25593f Mon Sep 17 00:00:00 2001 From: siddontang Date: Sun, 18 Oct 2015 16:09:08 +0800 Subject: [PATCH 38/54] *: CheckReferAmbiguous -> GetIndex and update --- plan/plans/select_list.go | 34 ++++++++++++++++++++-------------- plan/plans/select_list_test.go | 10 +++++----- rset/rsets/having.go | 2 +- rset/rsets/helper.go | 30 +++++++----------------------- tidb_test.go | 3 +++ 5 files changed, 36 insertions(+), 43 deletions(-) diff --git a/plan/plans/select_list.go b/plan/plans/select_list.go index 66704acdfb..bb44f32e03 100644 --- a/plan/plans/select_list.go +++ b/plan/plans/select_list.go @@ -124,35 +124,41 @@ func (s *SelectList) UpdateAggFields(expr expression.Expression) (expression.Exp return &expression.Position{N: index + 1, Name: name}, nil } -// CheckReferAmbiguous checks whether an identifier reference is ambiguous or not in select list. +// GetIndex tries to find the index where contains the indentifier. +// It checks whether an identifier reference is ambiguous or not in select list. // e,g, "select c1 as a, c2 as a from t group by a" is ambiguous, // but "select c1 as a, c1 as a from t group by a" is not. -// For MySQL "select c1 as a, c2 + 1 as a from t group by a" is not ambiguous too, -// so we will only check identifier too. -// If no ambiguous, nil means expr refers none in select list, else an indices for fields have the same name in select list returns. -func (s *SelectList) CheckReferAmbiguous(expr expression.Expression) ([]int, error) { +// "select c1 as a, c2 + 1 as a from t group by a" is not ambiguous too, +// If no ambiguous, -1 means expr refers none in select list, else an index for first match. +// GetIndex will break the check when finding first matching which is not an indentifier, +// or an index for an identifier field in the end, -1 means none found. +func (s *SelectList) GetIndex(expr expression.Expression) (int, error) { if _, ok := expr.(*expression.Ident); !ok { - return nil, nil + return -1, nil } name := expr.String() if field.IsQualifiedName(name) { // name is qualified, no need to check - return nil, nil + return -1, nil } + // select c1 as a, 1 as a, c2 as a from t order by a is not ambiguous. + // select c1 as a, c2 as a from t order by a is ambiguous. + // select 1 as a, c1 as a from t order by a is not ambiguous. + // select c1 as a, sum(c1) as a from t group by a is error. + // select c1 as a, 1 as a, sum(c1) as a from t group by a is not error. + // so we will break the check if matching a none identifier field. lastIndex := -1 - var idx []int // only check origin select list, no hidden field. for i := 0; i < s.HiddenFieldOffset; i++ { if !strings.EqualFold(s.ResultFields[i].Name, name) { continue } - idx = append(idx, i) if _, ok := s.Fields[i].Expr.(*expression.Ident); !ok { - // not identfier, no check - continue + // not identfier, return directly. + return i, nil } if lastIndex == -1 { @@ -164,18 +170,18 @@ func (s *SelectList) CheckReferAmbiguous(expr expression.Expression) ([]int, err // check origin name, e,g. "select c1 as c2, c2 from t group by c2" is ambiguous. if s.ResultFields[i].ColumnInfo.Name.L != s.ResultFields[lastIndex].ColumnInfo.Name.L { - return nil, errors.Errorf("refer %s is ambiguous", expr) + return -1, errors.Errorf("refer %s is ambiguous", expr) } // check table name, e.g, "select t.c1, c1 from t group by c1" is not ambiguous. if s.ResultFields[i].TableName != s.ResultFields[lastIndex].TableName { - return nil, errors.Errorf("refer %s is ambiguous", expr) + return -1, errors.Errorf("refer %s is ambiguous", expr) } // TODO: check database name if possible. } - return idx, nil + return lastIndex, nil } // ResolveSelectList gets fields and result fields from selectFields and srcFields, diff --git a/plan/plans/select_list_test.go b/plan/plans/select_list_test.go index ffef1a0981..27468ef91d 100644 --- a/plan/plans/select_list_test.go +++ b/plan/plans/select_list_test.go @@ -82,11 +82,11 @@ func (s *testSelectListSuite) TestAmbiguous(c *C) { Fields []pair Name string Err bool - Index []int + Index int }{ - {[]pair{{"id", ""}}, "id", false, []int{0}}, - {[]pair{{"id", "a"}, {"name", "a"}}, "a", true, nil}, - {[]pair{{"id", "a"}, {"name", "a"}}, "id", false, nil}, + {[]pair{{"id", ""}}, "id", false, 0}, + {[]pair{{"id", "a"}, {"name", "a"}}, "a", true, -1}, + {[]pair{{"id", "a"}, {"name", "a"}}, "id", false, -1}, } for _, t := range tbl { @@ -96,7 +96,7 @@ func (s *testSelectListSuite) TestAmbiguous(c *C) { sl, err := plans.ResolveSelectList(fs, rs) c.Assert(err, IsNil) - idx, err := sl.CheckReferAmbiguous(&expression.Ident{ + idx, err := sl.GetIndex(&expression.Ident{ CIStr: model.NewCIStr(t.Name), }) diff --git a/rset/rsets/having.go b/rset/rsets/having.go index 6119227a3f..a678ada1c4 100644 --- a/rset/rsets/having.go +++ b/rset/rsets/having.go @@ -137,7 +137,7 @@ func (v *havingVisitor) checkIdentInSelectList(i *expression.Ident) (*expression if index >= 0 { // identifier references a select field. use it directly. - // e,g. select c1 as c2 from t order by c2, here c2 references c1. + // e,g. select c1 as c2 from t having c2, here c2 references c1. i.ReferScope = expression.IdentReferSelectList i.ReferIndex = index return i, true, nil diff --git a/rset/rsets/helper.go b/rset/rsets/helper.go index 90208228e1..ab9d41c058 100644 --- a/rset/rsets/helper.go +++ b/rset/rsets/helper.go @@ -110,37 +110,21 @@ func castPosition(e expression.Expression, selectList *plans.SelectList, clause } func checkIdent(i *expression.Ident, selectList *plans.SelectList, clause clauseType) (int, error) { - idx, err := selectList.CheckReferAmbiguous(i) + index, err := selectList.GetIndex(i) if err != nil { return -1, errors.Errorf("Column '%s' in %s is ambiguous", i, clause) - } else if len(idx) == 0 { + } else if index == -1 { return -1, nil } - // this identifier may reference multi fields. - // e.g, select c1 as a, c2 + 1 as a from t group by a, - // we will use the first one which is not an identifer. - // so, for select c1 as a, c2 + 1 as a from t group by a, we will use c2 + 1. - - useIndex := 0 - found := false - for _, index := range idx { - if clause == groupByClause { - // group by can not reference aggregate fields - if _, ok := selectList.AggFields[index]; ok { - return -1, errors.Errorf("Reference '%s' not supported (reference to group function)", i) - } - } - - if !found { - if castIdent(selectList.Fields[index].Expr) == nil { - useIndex = index - found = true - } + if clause == groupByClause { + // group by can not reference aggregate fields + if _, ok := selectList.AggFields[index]; ok { + return -1, errors.Errorf("Reference '%s' not supported (reference to group function)", i) } } - return idx[useIndex], nil + return index, nil } // fromIdentVisitor can only handle identifier which reference FROM table or outer query. diff --git a/tidb_test.go b/tidb_test.go index 43430e4890..937ce38999 100644 --- a/tidb_test.go +++ b/tidb_test.go @@ -1142,6 +1142,9 @@ func (s *testSessionSuite) TestGroupBy(c *C) { mustExecSQL(c, se, "create table t (c1 int, c2 int)") mustExecSQL(c, se, "insert into t values (1,1), (2,2), (1,2), (1,3)") mustExecMatch(c, se, "select nullif (count(*), 2);", [][]interface{}{{1}}) + mustExecMatch(c, se, "select 1 as a, sum(c1) as a from t group by a", [][]interface{}{{1, 5}}) + mustExecMatch(c, se, "select c1 as a, 1 as a, sum(c1) as a from t group by a", [][]interface{}{{1, 1, 5}}) + mustExecMatch(c, se, "select c1 as a, 1 as a, c2 as a from t group by a;", [][]interface{}{{1, 1, 1}}) mustExecMatch(c, se, "select c1 as c2, c2 from t group by c2 + 1", [][]interface{}{{1, 1}, {2, 2}, {1, 3}}) mustExecMatch(c, se, "select c1 as c2, count(c1) from t group by c2", [][]interface{}{{1, 1}, {2, 2}, {1, 1}}) From a70ff17b1a70f2125f9fd85aa97f713beabfcb40 Mon Sep 17 00:00:00 2001 From: siddontang Date: Mon, 19 Oct 2015 09:34:17 +0800 Subject: [PATCH 39/54] rsets: Address comment. --- rset/rsets/having.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rset/rsets/having.go b/rset/rsets/having.go index a678ada1c4..acbca24db6 100644 --- a/rset/rsets/having.go +++ b/rset/rsets/having.go @@ -199,8 +199,8 @@ func (v *havingVisitor) VisitIdent(i *expression.Ident) (expression.Expression, // select c1 as c2, c2 from t having c2 is ambiguous // select c1 as c2, c2 from t having c2 + 1 is ambiguous - // The identifier in having must exist in group by, select list and outer query, so select c1 from t having c2 is wrong, - // the idenfitier will first try to find the reference in group by, then in select list and finally in outer query. + // The identifier in having must exist in group by, select list or outer query, so select c1 from t having c2 is wrong, + // the identifier will first try to find the reference in group by, then in select list and finally in outer query. // Having identifier reference check // select c1 from t group by c2 having c2, c2 in having references group by c2. From 129f69292b2d9b3a6d1d1a39539507ae600841a1 Mon Sep 17 00:00:00 2001 From: siddontang Date: Mon, 19 Oct 2015 11:42:59 +0800 Subject: [PATCH 40/54] *: fix having aggregate bug --- rset/rsets/having.go | 14 ++++++++++++-- tidb_test.go | 2 ++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/rset/rsets/having.go b/rset/rsets/having.go index acbca24db6..8b1ee57cef 100644 --- a/rset/rsets/having.go +++ b/rset/rsets/having.go @@ -261,7 +261,15 @@ func (v *havingVisitor) VisitCall(c *expression.Call) (expression.Expression, er return nil, errors.Trace(err) } - v.inAggregate = isAggregate + if v.inAggregate && isAggregate { + // aggregate function can't contain aggregate function + return nil, errors.Errorf("Invalid use of group function") + } + + if isAggregate { + // set true to let outer know we are in aggregate function. + v.inAggregate = true + } for i, e := range c.Args { c.Args[i], err = e.Accept(v) @@ -270,7 +278,9 @@ func (v *havingVisitor) VisitCall(c *expression.Call) (expression.Expression, er } } - v.inAggregate = false + if isAggregate { + v.inAggregate = false + } return c, nil } diff --git a/tidb_test.go b/tidb_test.go index 937ce38999..b9fdfd661c 100644 --- a/tidb_test.go +++ b/tidb_test.go @@ -1199,6 +1199,7 @@ func (s *testSessionSuite) TestHaving(c *C) { mustExecMatch(c, se, "select sum(c1) from t group by c1 having sum(c1)", [][]interface{}{{1}, {2}, {3}}) mustExecMatch(c, se, "select sum(c1) - 1 from t group by c1 having sum(c1) - 1", [][]interface{}{{1}, {2}}) + mustExecMatch(c, se, "select 1 from t group by c1 having sum(abs(c2 + c3)) = c1", [][]interface{}{{1}}) mustExecFailed(c, se, "select c1 from t having c2") mustExecFailed(c, se, "select c1 from t having c2 + 1") @@ -1211,6 +1212,7 @@ func (s *testSessionSuite) TestHaving(c *C) { mustExecFailed(c, se, "select c1 + 1 from t having c1") mustExecFailed(c, se, "select c1 + 1 from t having c1 + 1") mustExecFailed(c, se, "select a.c1 as c, b.c1 as d from t as a, t as b having c1") + mustExecFailed(c, se, "select 1 from t having sum(avg(c1))") } func newSession(c *C, store kv.Storage, dbName string) Session { From fb3a51d834bf54a901e0da850ade88d2181ec95a Mon Sep 17 00:00:00 2001 From: siddontang Date: Mon, 19 Oct 2015 12:51:53 +0800 Subject: [PATCH 41/54] *: checkIdent -> checkIdentAmbiguous --- plan/plans/select_list.go | 7 +++---- plan/plans/select_list_test.go | 2 +- rset/rsets/groupby.go | 9 +++++++-- rset/rsets/having.go | 4 ++-- rset/rsets/helper.go | 11 ++--------- rset/rsets/orderby.go | 6 +++--- stmt/stmts/select.go | 2 +- tidb_test.go | 1 + 8 files changed, 20 insertions(+), 22 deletions(-) diff --git a/plan/plans/select_list.go b/plan/plans/select_list.go index bb44f32e03..52b76fa26a 100644 --- a/plan/plans/select_list.go +++ b/plan/plans/select_list.go @@ -124,15 +124,14 @@ func (s *SelectList) UpdateAggFields(expr expression.Expression) (expression.Exp return &expression.Position{N: index + 1, Name: name}, nil } -// GetIndex tries to find the index where contains the indentifier. -// It checks whether an identifier reference is ambiguous or not in select list. +// CheckAmbiguous checks whether an identifier reference is ambiguous or not in select list. // e,g, "select c1 as a, c2 as a from t group by a" is ambiguous, // but "select c1 as a, c1 as a from t group by a" is not. // "select c1 as a, c2 + 1 as a from t group by a" is not ambiguous too, // If no ambiguous, -1 means expr refers none in select list, else an index for first match. -// GetIndex will break the check when finding first matching which is not an indentifier, +// CheckAmbiguous will break the check when finding first matching which is not an indentifier, // or an index for an identifier field in the end, -1 means none found. -func (s *SelectList) GetIndex(expr expression.Expression) (int, error) { +func (s *SelectList) CheckAmbiguous(expr expression.Expression) (int, error) { if _, ok := expr.(*expression.Ident); !ok { return -1, nil } diff --git a/plan/plans/select_list_test.go b/plan/plans/select_list_test.go index 27468ef91d..7db298422d 100644 --- a/plan/plans/select_list_test.go +++ b/plan/plans/select_list_test.go @@ -96,7 +96,7 @@ func (s *testSelectListSuite) TestAmbiguous(c *C) { sl, err := plans.ResolveSelectList(fs, rs) c.Assert(err, IsNil) - idx, err := sl.GetIndex(&expression.Ident{ + idx, err := sl.CheckAmbiguous(&expression.Ident{ CIStr: model.NewCIStr(t.Name), }) diff --git a/rset/rsets/groupby.go b/rset/rsets/groupby.go index 2ac5328523..5c18aaea0b 100644 --- a/rset/rsets/groupby.go +++ b/rset/rsets/groupby.go @@ -59,7 +59,7 @@ func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, if v.rootIdent == i { // The group by is an identifier, we must check it first. - index, err = checkIdent(i, v.selectList, groupByClause) + index, err = checkIdentAmbiguous(i, v.selectList, groupByClause) if err != nil { return nil, errors.Trace(err) } @@ -75,7 +75,7 @@ func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, if v.rootIdent != i { // This identifier is the part of the group by, check ambiguous here. - index, err = checkIdent(i, v.selectList, groupByClause) + index, err = checkIdentAmbiguous(i, v.selectList, groupByClause) if err != nil { return nil, errors.Trace(err) } @@ -83,6 +83,11 @@ func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, // try to find in select list, we have got index using checkIdent before. if index >= 0 { + // group by can not reference aggregate fields + if _, ok := v.selectList.AggFields[index]; ok { + return nil, errors.Errorf("Reference '%s' not supported (reference to group function)", i) + } + // find in select list i.ReferScope = expression.IdentReferSelectList i.ReferIndex = index diff --git a/rset/rsets/having.go b/rset/rsets/having.go index 8b1ee57cef..be587ebcb0 100644 --- a/rset/rsets/having.go +++ b/rset/rsets/having.go @@ -74,7 +74,7 @@ func (v *havingVisitor) visitIdentInAggregate(i *expression.Ident) (expression.E } // check in select list. - index, err := checkIdent(i, v.selectList, havingClause) + index, err := checkIdentAmbiguous(i, v.selectList, havingClause) if err != nil { return i, errors.Trace(err) } @@ -130,7 +130,7 @@ func (v *havingVisitor) checkIdentInGroupBy(i *expression.Ident) (*expression.Id } func (v *havingVisitor) checkIdentInSelectList(i *expression.Ident) (*expression.Ident, bool, error) { - index, err := checkIdent(i, v.selectList, havingClause) + index, err := checkIdentAmbiguous(i, v.selectList, havingClause) if err != nil { return i, false, errors.Trace(err) } diff --git a/rset/rsets/helper.go b/rset/rsets/helper.go index ab9d41c058..91aff99379 100644 --- a/rset/rsets/helper.go +++ b/rset/rsets/helper.go @@ -109,21 +109,14 @@ func castPosition(e expression.Expression, selectList *plans.SelectList, clause return &expression.Position{N: position}, nil } -func checkIdent(i *expression.Ident, selectList *plans.SelectList, clause clauseType) (int, error) { - index, err := selectList.GetIndex(i) +func checkIdentAmbiguous(i *expression.Ident, selectList *plans.SelectList, clause clauseType) (int, error) { + index, err := selectList.CheckAmbiguous(i) if err != nil { return -1, errors.Errorf("Column '%s' in %s is ambiguous", i, clause) } else if index == -1 { return -1, nil } - if clause == groupByClause { - // group by can not reference aggregate fields - if _, ok := selectList.AggFields[index]; ok { - return -1, errors.Errorf("Reference '%s' not supported (reference to group function)", i) - } - } - return index, nil } diff --git a/rset/rsets/orderby.go b/rset/rsets/orderby.go index f71634373b..9412cc9faa 100644 --- a/rset/rsets/orderby.go +++ b/rset/rsets/orderby.go @@ -111,7 +111,7 @@ func (v *orderByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, if v.rootIdent == i { // The order by is an identifier, we must check it first. - index, err = checkIdent(i, v.selectList, orderByClause) + index, err = checkIdentAmbiguous(i, v.selectList, orderByClause) if err != nil { return nil, errors.Trace(err) } @@ -135,13 +135,13 @@ func (v *orderByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, if v.rootIdent != i { // This identifier is the part of the order by, check ambiguous here. - index, err = checkIdent(i, v.selectList, orderByClause) + index, err = checkIdentAmbiguous(i, v.selectList, orderByClause) if err != nil { return nil, errors.Trace(err) } } - // try to find in select list, we have got index using checkIdent before. + // try to find in select list, we have got index using checkIdentAmbiguous before. if index >= 0 { // find in select list i.ReferScope = expression.IdentReferSelectList diff --git a/stmt/stmts/select.go b/stmt/stmts/select.go index a7ecfc75ce..72b72594bd 100644 --- a/stmt/stmts/select.go +++ b/stmt/stmts/select.go @@ -188,7 +188,7 @@ func (s *SelectStmt) Plan(ctx context.Context) (plan.Plan, error) { } switch { - case !rsets.HasAggFields(selectList.Fields) && s.GroupBy == nil: + case len(selectList.AggFields) == 0 && s.GroupBy == nil: // If no group by and no aggregate functions, we will use SelectFieldsPlan. if r, err = (&rsets.SelectFieldsRset{Src: r, SelectList: selectList, diff --git a/tidb_test.go b/tidb_test.go index b9fdfd661c..f0962800c9 100644 --- a/tidb_test.go +++ b/tidb_test.go @@ -1145,6 +1145,7 @@ func (s *testSessionSuite) TestGroupBy(c *C) { mustExecMatch(c, se, "select 1 as a, sum(c1) as a from t group by a", [][]interface{}{{1, 5}}) mustExecMatch(c, se, "select c1 as a, 1 as a, sum(c1) as a from t group by a", [][]interface{}{{1, 1, 5}}) mustExecMatch(c, se, "select c1 as a, 1 as a, c2 as a from t group by a;", [][]interface{}{{1, 1, 1}}) + mustExecMatch(c, se, "select c1 as c2, sum(c1) as c2 from t group by c2;", [][]interface{}{{1, 1}, {2, 3}, {1, 1}}) mustExecMatch(c, se, "select c1 as c2, c2 from t group by c2 + 1", [][]interface{}{{1, 1}, {2, 2}, {1, 3}}) mustExecMatch(c, se, "select c1 as c2, count(c1) from t group by c2", [][]interface{}{{1, 1}, {2, 2}, {1, 1}}) From 83086b57a77200aa034238599ce3b4e5109add38 Mon Sep 17 00:00:00 2001 From: siddontang Date: Mon, 19 Oct 2015 14:11:41 +0800 Subject: [PATCH 42/54] *: Address comment. --- expression/helper.go | 12 +++++------- rset/rsets/groupby.go | 7 +++---- rset/rsets/having.go | 6 ++---- 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/expression/helper.go b/expression/helper.go index e0d8c5b639..e38e889eb9 100644 --- a/expression/helper.go +++ b/expression/helper.go @@ -144,10 +144,8 @@ func newMentionedAggregateFuncsVisitor() *mentionedAggregateFuncsVisitor { } func (v *mentionedAggregateFuncsVisitor) VisitCall(c *Call) (Expression, error) { - isAggregate, err := IsAggregateFunc(c.F) - if err != nil { - return nil, errors.Trace(err) - } + isAggregate := IsAggregateFunc(c.F) + if isAggregate { v.exprs = append(v.exprs, c) } @@ -170,12 +168,12 @@ func (v *mentionedAggregateFuncsVisitor) VisitCall(c *Call) (Expression, error) } // IsAggregateFunc checks whether name is an aggregate function or not. -func IsAggregateFunc(name string) (bool, error) { +func IsAggregateFunc(name string) bool { f, ok := builtin.Funcs[strings.ToLower(name)] if !ok { - return false, errors.Errorf("unknown function %s", name) + return false } - return f.IsAggregate, nil + return f.IsAggregate } // MentionedColumns returns a list of names for Ident expression. diff --git a/rset/rsets/groupby.go b/rset/rsets/groupby.go index 5c18aaea0b..d95ce8f073 100644 --- a/rset/rsets/groupby.go +++ b/rset/rsets/groupby.go @@ -99,13 +99,12 @@ func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, } func (v *groupByVisitor) VisitCall(c *expression.Call) (expression.Expression, error) { - ok, err := expression.IsAggregateFunc(c.F) - if err != nil { - return nil, errors.Trace(err) - } else if ok { + ok := expression.IsAggregateFunc(c.F) + if ok { return nil, errors.Errorf("group by cannot contain aggregate function %s", c) } + var err error for i, e := range c.Args { c.Args[i], err = e.Accept(v) if err != nil { diff --git a/rset/rsets/having.go b/rset/rsets/having.go index be587ebcb0..be43f3cc02 100644 --- a/rset/rsets/having.go +++ b/rset/rsets/having.go @@ -256,10 +256,7 @@ func (v *havingVisitor) VisitPosition(p *expression.Position) (expression.Expres } func (v *havingVisitor) VisitCall(c *expression.Call) (expression.Expression, error) { - isAggregate, err := expression.IsAggregateFunc(c.F) - if err != nil { - return nil, errors.Trace(err) - } + isAggregate := expression.IsAggregateFunc(c.F) if v.inAggregate && isAggregate { // aggregate function can't contain aggregate function @@ -271,6 +268,7 @@ func (v *havingVisitor) VisitCall(c *expression.Call) (expression.Expression, er v.inAggregate = true } + var err error for i, e := range c.Args { c.Args[i], err = e.Accept(v) if err != nil { From cee3cec5067b75ab4383d85e9908b616d0e48489 Mon Sep 17 00:00:00 2001 From: siddontang Date: Mon, 19 Oct 2015 14:21:45 +0800 Subject: [PATCH 43/54] expression: Address comment. --- expression/helper.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/expression/helper.go b/expression/helper.go index e38e889eb9..91023c6006 100644 --- a/expression/helper.go +++ b/expression/helper.go @@ -169,6 +169,8 @@ func (v *mentionedAggregateFuncsVisitor) VisitCall(c *Call) (Expression, error) // IsAggregateFunc checks whether name is an aggregate function or not. func IsAggregateFunc(name string) bool { + // TODO: use switch defined aggregate name "sum", "count", etc... directly. + // Maybe we can remove builtin IsAggregate field later. f, ok := builtin.Funcs[strings.ToLower(name)] if !ok { return false From 1da15de1fe61852096bd7f8ef8e69bc10cfebf87 Mon Sep 17 00:00:00 2001 From: qiuyesuifeng Date: Mon, 19 Oct 2015 14:33:37 +0800 Subject: [PATCH 44/54] *: address comments. --- mysqldef/time.go | 4 ++-- parser/parser.y | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mysqldef/time.go b/mysqldef/time.go index 21c81455d9..e7f909b301 100644 --- a/mysqldef/time.go +++ b/mysqldef/time.go @@ -1091,12 +1091,12 @@ func extractSecondMicrosecond(format string) (int64, int64, int64, time.Duration return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) } - microsecond, err := strconv.ParseInt(alignFrac(fields[1], MaxFsp), 10, 64) + microseconds, err := strconv.ParseInt(alignFrac(fields[1], MaxFsp), 10, 64) if err != nil { return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) } - return 0, 0, 0, time.Duration(seconds)*time.Second + time.Duration(microsecond)*time.Microsecond, nil + return 0, 0, 0, time.Duration(seconds)*time.Second + time.Duration(microseconds)*time.Microsecond, nil } // Format is `MM:SS.FFFFFF`. diff --git a/parser/parser.y b/parser/parser.y index 758e334fcb..e3ca085133 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -1692,7 +1692,7 @@ UnReservedKeyword: | "NATIONAL" | "ROW" | "QUARTER" | "ESCAPE" NotKeywordToken: - "ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT" + "ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DATE_ADD" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT" | "HOUR" | "IFNULL" | "LENGTH" | "LOCATE" | "MAX" | "MICROSECOND" | "MIN" | "MINUTE" | "NULLIF" | "MONTH" | "NOW" | "RAND" | "SECOND" | "SQL_CALC_FOUND_ROWS" | "SUBSTRING" %prec lowerThanLeftParen | "SUBSTRING_INDEX" | "SUM" | "TRIM" | "WEEKDAY" | "WEEKOFYEAR" | "YEARWEEK" From 94612bc56b7371847cc3995764feac26a215adee Mon Sep 17 00:00:00 2001 From: qiuyesuifeng Date: Mon, 19 Oct 2015 14:42:26 +0800 Subject: [PATCH 45/54] stmt: fix go vet check error. --- stmt/stmts/grant.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/stmt/stmts/grant.go b/stmt/stmts/grant.go index 841ecb8769..d4dbcbab2a 100644 --- a/stmt/stmts/grant.go +++ b/stmt/stmts/grant.go @@ -305,7 +305,7 @@ func composeGlobalPrivUpdate(priv mysql.PrivilegeType) (string, error) { } col, ok := mysql.Priv2UserCol[priv] if !ok { - return "", errors.Errorf("Unknown priv: %s", priv) + return "", errors.Errorf("Unknown priv: %v", priv) } return fmt.Sprintf(`%s="Y"`, col), nil } @@ -317,7 +317,7 @@ func composeDBPrivUpdate(priv mysql.PrivilegeType) (string, error) { for _, p := range mysql.AllDBPrivs { v, ok := mysql.Priv2UserCol[p] if !ok { - return "", errors.Errorf("Unknown db privilege %s", priv) + return "", errors.Errorf("Unknown db privilege %v", priv) } strs = append(strs, fmt.Sprintf(`%s="Y"`, v)) } @@ -325,7 +325,7 @@ func composeDBPrivUpdate(priv mysql.PrivilegeType) (string, error) { } col, ok := mysql.Priv2UserCol[priv] if !ok { - return "", errors.Errorf("Unknown priv: %s", priv) + return "", errors.Errorf("Unknown priv: %v", priv) } return fmt.Sprintf(`%s="Y"`, col), nil } @@ -363,7 +363,7 @@ func composeTablePrivUpdate(ctx context.Context, priv mysql.PrivilegeType, name } p, ok := mysql.Priv2SetStr[priv] if !ok { - return "", errors.Errorf("Unknown priv: %s", priv) + return "", errors.Errorf("Unknown priv: %v", priv) } if len(currTablePriv) == 0 { newTablePriv = p @@ -406,7 +406,7 @@ func composeColumnPrivUpdate(ctx context.Context, priv mysql.PrivilegeType, name } p, ok := mysql.Priv2SetStr[priv] if !ok { - return "", errors.Errorf("Unknown priv: %s", priv) + return "", errors.Errorf("Unknown priv: %v", priv) } if len(currColumnPriv) == 0 { newColumnPriv = p From 29369dba8fcf09498ac0551ef4df4b3712ac9779 Mon Sep 17 00:00:00 2001 From: ngaut Date: Mon, 19 Oct 2015 16:26:43 +0800 Subject: [PATCH 46/54] *: Rename mysqldef to mysql --- bootstrap.go | 2 +- column/column.go | 2 +- column/column_test.go | 2 +- ddl/ddl.go | 2 +- ddl/ddl_test.go | 2 +- driver.go | 2 +- expression/binop.go | 2 +- expression/binop_test.go | 2 +- expression/builtin/groupby.go | 2 +- expression/builtin/groupby_test.go | 2 +- expression/builtin/string_test.go | 2 +- expression/builtin/time.go | 2 +- expression/builtin/time_test.go | 2 +- expression/cast.go | 2 +- expression/cast_test.go | 2 +- expression/date_add.go | 2 +- expression/date_add_test.go | 2 +- expression/extract.go | 2 +- expression/helper.go | 2 +- expression/helper_test.go | 2 +- expression/unary.go | 2 +- expression/unary_test.go | 2 +- expression/visitor_test.go | 2 +- field/field_test.go | 2 +- field/result_field.go | 2 +- field/result_field_test.go | 2 +- infoschema/infoschema_test.go | 4 ++-- {mysqldef => mysql}/bit.go | 2 +- {mysqldef => mysql}/bit_test.go | 2 +- {mysqldef => mysql}/charset.go | 2 +- {mysqldef => mysql}/const.go | 2 +- {mysqldef => mysql}/decimal.go | 2 +- {mysqldef => mysql}/decimal_test.go | 2 +- {mysqldef => mysql}/enum.go | 2 +- {mysqldef => mysql}/enum_test.go | 2 +- {mysqldef => mysql}/errcode.go | 2 +- {mysqldef => mysql}/errname.go | 2 +- {mysqldef => mysql}/error.go | 2 +- {mysqldef => mysql}/error_test.go | 2 +- {mysqldef => mysql}/fsp.go | 2 +- {mysqldef => mysql}/hex.go | 2 +- {mysqldef => mysql}/hex_test.go | 2 +- {mysqldef => mysql}/set.go | 2 +- {mysqldef => mysql}/set_test.go | 2 +- {mysqldef => mysql}/state.go | 2 +- {mysqldef => mysql}/time.go | 2 +- {mysqldef => mysql}/time_test.go | 2 +- {mysqldef => mysql}/type.go | 2 +- {mysqldef => mysql}/type_test.go | 2 +- {mysqldef => mysql}/util.go | 2 +- {mysqldef => mysql}/util_test.go | 2 +- parser/coldef/col_def.go | 2 +- parser/coldef/opt.go | 2 +- parser/parser.y | 2 +- parser/scanner.l | 2 +- plan/plans/final_test.go | 2 +- plan/plans/from_test.go | 2 +- plan/plans/index_test.go | 2 +- plan/plans/info.go | 2 +- plan/plans/join_test.go | 2 +- plan/plans/plans.go | 2 +- plan/plans/show.go | 2 +- plan/plans/show_test.go | 2 +- plan/plans/union_test.go | 2 +- session.go | 2 +- sessionctx/variable/session.go | 2 +- stmt/stmts/account_manage.go | 2 +- stmt/stmts/create_test.go | 2 +- stmt/stmts/grant.go | 4 ++-- stmt/stmts/grant_test.go | 2 +- stmt/stmts/insert.go | 2 +- stmt/stmts/stmt_helper.go | 2 +- stmt/stmts/transaction.go | 2 +- stmt/stmts/update.go | 2 +- table/tables/tables.go | 2 +- table/tables/tables_test.go | 2 +- tidb-server/server/conn.go | 2 +- tidb-server/server/conn_stmt.go | 2 +- tidb-server/server/driver_tidb.go | 2 +- tidb-server/server/packetio.go | 2 +- tidb-server/server/server.go | 2 +- tidb-server/server/util.go | 2 +- tidb-server/server/util_test.go | 2 +- tidb_test.go | 2 +- util/codec/codec.go | 2 +- util/codec/codec_test.go | 2 +- util/codec/decimal.go | 2 +- util/codec/decimal_test.go | 2 +- util/types/compare.go | 2 +- util/types/compare_test.go | 2 +- util/types/convert.go | 2 +- util/types/convert_test.go | 2 +- util/types/etc.go | 2 +- util/types/etc_test.go | 2 +- util/types/field_type.go | 2 +- util/types/field_type_test.go | 2 +- 96 files changed, 98 insertions(+), 98 deletions(-) rename {mysqldef => mysql}/bit.go (99%) rename {mysqldef => mysql}/bit_test.go (98%) rename {mysqldef => mysql}/charset.go (99%) rename {mysqldef => mysql}/const.go (99%) rename {mysqldef => mysql}/decimal.go (99%) rename {mysqldef => mysql}/decimal_test.go (99%) rename {mysqldef => mysql}/enum.go (98%) rename {mysqldef => mysql}/enum_test.go (98%) rename {mysqldef => mysql}/errcode.go (99%) rename {mysqldef => mysql}/errname.go (99%) rename {mysqldef => mysql}/error.go (99%) rename {mysqldef => mysql}/error_test.go (98%) rename {mysqldef => mysql}/fsp.go (99%) rename {mysqldef => mysql}/hex.go (99%) rename {mysqldef => mysql}/hex_test.go (98%) rename {mysqldef => mysql}/set.go (99%) rename {mysqldef => mysql}/set_test.go (99%) rename {mysqldef => mysql}/state.go (99%) rename {mysqldef => mysql}/time.go (99%) rename {mysqldef => mysql}/time_test.go (99%) rename {mysqldef => mysql}/type.go (99%) rename {mysqldef => mysql}/type_test.go (98%) rename {mysqldef => mysql}/util.go (98%) rename {mysqldef => mysql}/util_test.go (98%) diff --git a/bootstrap.go b/bootstrap.go index ead2091356..8de29f1d9b 100644 --- a/bootstrap.go +++ b/bootstrap.go @@ -22,7 +22,7 @@ import ( "runtime/debug" "github.com/ngaut/log" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/errors" "github.com/pingcap/tidb/util/errors2" ) diff --git a/column/column.go b/column/column.go index 6acdfe2c37..d6831c8f97 100644 --- a/column/column.go +++ b/column/column.go @@ -24,7 +24,7 @@ import ( "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/types" ) diff --git a/column/column_test.go b/column/column_test.go index d4ae1a2a35..45710b588b 100644 --- a/column/column_test.go +++ b/column/column_test.go @@ -18,7 +18,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/types" ) diff --git a/ddl/ddl.go b/ddl/ddl.go index d8130e74cf..2525a9327b 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -30,7 +30,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/coldef" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" diff --git a/ddl/ddl_test.go b/ddl/ddl_test.go index 2945b69388..864567a29f 100644 --- a/ddl/ddl_test.go +++ b/ddl/ddl_test.go @@ -23,7 +23,7 @@ import ( "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/stmt" diff --git a/driver.go b/driver.go index 31d58c5acd..ad2f08647c 100644 --- a/driver.go +++ b/driver.go @@ -31,7 +31,7 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx" qerror "github.com/pingcap/tidb/util/errors" diff --git a/expression/binop.go b/expression/binop.go index d466a3ddd9..8ce8cf2b54 100644 --- a/expression/binop.go +++ b/expression/binop.go @@ -23,7 +23,7 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/context" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/util/types" ) diff --git a/expression/binop_test.go b/expression/binop_test.go index 81f14ce804..ee2c02f3a2 100644 --- a/expression/binop_test.go +++ b/expression/binop_test.go @@ -20,7 +20,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/expression/builtin" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/util/types" ) diff --git a/expression/builtin/groupby.go b/expression/builtin/groupby.go index 69cd07253f..5566c3897f 100644 --- a/expression/builtin/groupby.go +++ b/expression/builtin/groupby.go @@ -24,7 +24,7 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/kv/memkv" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/types" ) diff --git a/expression/builtin/groupby_test.go b/expression/builtin/groupby_test.go index a6116bd968..0730725bda 100644 --- a/expression/builtin/groupby_test.go +++ b/expression/builtin/groupby_test.go @@ -15,7 +15,7 @@ package builtin import ( . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/types" ) diff --git a/expression/builtin/string_test.go b/expression/builtin/string_test.go index 878191115c..5c128a47d9 100644 --- a/expression/builtin/string_test.go +++ b/expression/builtin/string_test.go @@ -19,7 +19,7 @@ import ( "time" . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) func (s *testBuiltinSuite) TestLength(c *C) { diff --git a/expression/builtin/time.go b/expression/builtin/time.go index 22e9e07468..ecfb62002e 100644 --- a/expression/builtin/time.go +++ b/expression/builtin/time.go @@ -21,7 +21,7 @@ import ( "time" "github.com/juju/errors" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/types" ) diff --git a/expression/builtin/time_test.go b/expression/builtin/time_test.go index 151b0ed9ba..c3e662933f 100644 --- a/expression/builtin/time_test.go +++ b/expression/builtin/time_test.go @@ -18,7 +18,7 @@ import ( "time" . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) func (s *testBuiltinSuite) TestDate(c *C) { diff --git a/expression/cast.go b/expression/cast.go index 42f54420b1..4b0d549e79 100644 --- a/expression/cast.go +++ b/expression/cast.go @@ -18,7 +18,7 @@ import ( "github.com/pingcap/tidb/context" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/types" ) diff --git a/expression/cast_test.go b/expression/cast_test.go index 7d969e5f06..72556f40f9 100644 --- a/expression/cast_test.go +++ b/expression/cast_test.go @@ -17,7 +17,7 @@ import ( "errors" . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/types" ) diff --git a/expression/date_add.go b/expression/date_add.go index b474347f4c..95cd99a0de 100644 --- a/expression/date_add.go +++ b/expression/date_add.go @@ -19,7 +19,7 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/context" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/types" ) diff --git a/expression/date_add_test.go b/expression/date_add_test.go index e13f3b2605..6b612dce40 100644 --- a/expression/date_add_test.go +++ b/expression/date_add_test.go @@ -15,7 +15,7 @@ package expression import ( . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) var _ = Suite(&testDateAddSuite{}) diff --git a/expression/extract.go b/expression/extract.go index abd0b9af57..32608cdddb 100644 --- a/expression/extract.go +++ b/expression/extract.go @@ -19,7 +19,7 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/context" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/types" ) diff --git a/expression/helper.go b/expression/helper.go index 69554c4354..3e1b693ca7 100644 --- a/expression/helper.go +++ b/expression/helper.go @@ -30,7 +30,7 @@ import ( "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression/builtin" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/types" diff --git a/expression/helper_test.go b/expression/helper_test.go index 1eacac2302..d1720fee64 100644 --- a/expression/helper_test.go +++ b/expression/helper_test.go @@ -6,7 +6,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/sessionctx/variable" ) diff --git a/expression/unary.go b/expression/unary.go index 7c58c128b1..f7c8e9ea6b 100644 --- a/expression/unary.go +++ b/expression/unary.go @@ -23,7 +23,7 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/context" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/util/types" ) diff --git a/expression/unary_test.go b/expression/unary_test.go index bef54d7982..aff61503ba 100644 --- a/expression/unary_test.go +++ b/expression/unary_test.go @@ -19,7 +19,7 @@ import ( . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/util/types" ) diff --git a/expression/visitor_test.go b/expression/visitor_test.go index 281c86c00d..b88315bfa6 100644 --- a/expression/visitor_test.go +++ b/expression/visitor_test.go @@ -17,7 +17,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/util/types" ) diff --git a/field/field_test.go b/field/field_test.go index c72d29b6e6..49cb09663c 100644 --- a/field/field_test.go +++ b/field/field_test.go @@ -19,7 +19,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/types" ) diff --git a/field/result_field.go b/field/result_field.go index 2eb301b9d2..fd79ab9e77 100644 --- a/field/result_field.go +++ b/field/result_field.go @@ -23,7 +23,7 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/column" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) const ( diff --git a/field/result_field_test.go b/field/result_field_test.go index 39308aefd5..9d58d86a20 100644 --- a/field/result_field_test.go +++ b/field/result_field_test.go @@ -19,7 +19,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/types" ) diff --git a/infoschema/infoschema_test.go b/infoschema/infoschema_test.go index ae197c93f4..f10910ead9 100644 --- a/infoschema/infoschema_test.go +++ b/infoschema/infoschema_test.go @@ -19,7 +19,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/model" - "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/store/localstore" "github.com/pingcap/tidb/store/localstore/goleveldb" "github.com/pingcap/tidb/util/types" @@ -51,7 +51,7 @@ func (*testSuite) TestT(c *C) { ID: 3, Name: colName, Offset: 0, - FieldType: *types.NewFieldType(mysqldef.TypeLonglong), + FieldType: *types.NewFieldType(mysql.TypeLonglong), } idxInfo := &model.IndexInfo{ diff --git a/mysqldef/bit.go b/mysql/bit.go similarity index 99% rename from mysqldef/bit.go rename to mysql/bit.go index 985c5b1d58..4f76dceb18 100644 --- a/mysqldef/bit.go +++ b/mysql/bit.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( "fmt" diff --git a/mysqldef/bit_test.go b/mysql/bit_test.go similarity index 98% rename from mysqldef/bit_test.go rename to mysql/bit_test.go index b2ca9d1080..49652206df 100644 --- a/mysqldef/bit_test.go +++ b/mysql/bit_test.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import . "github.com/pingcap/check" diff --git a/mysqldef/charset.go b/mysql/charset.go similarity index 99% rename from mysqldef/charset.go rename to mysql/charset.go index fb4809c362..962d792e50 100644 --- a/mysqldef/charset.go +++ b/mysql/charset.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql // CharsetIDs maps charset name to its default collation ID. var CharsetIDs = map[string]uint8{ diff --git a/mysqldef/const.go b/mysql/const.go similarity index 99% rename from mysqldef/const.go rename to mysql/const.go index 3f483c1eb0..e0f49e4294 100644 --- a/mysqldef/const.go +++ b/mysql/const.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql // Version informations. const ( diff --git a/mysqldef/decimal.go b/mysql/decimal.go similarity index 99% rename from mysqldef/decimal.go rename to mysql/decimal.go index 2ff129af19..87bd66e6b3 100644 --- a/mysqldef/decimal.go +++ b/mysql/decimal.go @@ -57,7 +57,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql // Decimal implements an arbitrary precision fixed-point decimal. // diff --git a/mysqldef/decimal_test.go b/mysql/decimal_test.go similarity index 99% rename from mysqldef/decimal_test.go rename to mysql/decimal_test.go index 60bb409c88..ab9748e232 100644 --- a/mysqldef/decimal_test.go +++ b/mysql/decimal_test.go @@ -57,7 +57,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( "encoding/json" diff --git a/mysqldef/enum.go b/mysql/enum.go similarity index 98% rename from mysqldef/enum.go rename to mysql/enum.go index 96c4712800..425a3e1b52 100644 --- a/mysqldef/enum.go +++ b/mysql/enum.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( "strconv" diff --git a/mysqldef/enum_test.go b/mysql/enum_test.go similarity index 98% rename from mysqldef/enum_test.go rename to mysql/enum_test.go index 0c06660493..154ac73802 100644 --- a/mysqldef/enum_test.go +++ b/mysql/enum_test.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( . "github.com/pingcap/check" diff --git a/mysqldef/errcode.go b/mysql/errcode.go similarity index 99% rename from mysqldef/errcode.go rename to mysql/errcode.go index d486c36c90..8de53d7e71 100644 --- a/mysqldef/errcode.go +++ b/mysql/errcode.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql // MySQL error code. // This value is numeric. It is not portable to other database systems. diff --git a/mysqldef/errname.go b/mysql/errname.go similarity index 99% rename from mysqldef/errname.go rename to mysql/errname.go index bc98fed16b..ad11ec746e 100644 --- a/mysqldef/errname.go +++ b/mysql/errname.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql // MySQLErrName maps error code to MySQL error messages. var MySQLErrName = map[uint16]string{ diff --git a/mysqldef/error.go b/mysql/error.go similarity index 99% rename from mysqldef/error.go rename to mysql/error.go index 6b75670fc8..43246a4a2a 100644 --- a/mysqldef/error.go +++ b/mysql/error.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( "errors" diff --git a/mysqldef/error_test.go b/mysql/error_test.go similarity index 98% rename from mysqldef/error_test.go rename to mysql/error_test.go index 10f9dd2d8e..1e499ff304 100644 --- a/mysqldef/error_test.go +++ b/mysql/error_test.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( . "github.com/pingcap/check" diff --git a/mysqldef/fsp.go b/mysql/fsp.go similarity index 99% rename from mysqldef/fsp.go rename to mysql/fsp.go index 46384d987e..820fabb3d1 100644 --- a/mysqldef/fsp.go +++ b/mysql/fsp.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( "math" diff --git a/mysqldef/hex.go b/mysql/hex.go similarity index 99% rename from mysqldef/hex.go rename to mysql/hex.go index 0a3b160d5d..5858b5d969 100644 --- a/mysqldef/hex.go +++ b/mysql/hex.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( "encoding/hex" diff --git a/mysqldef/hex_test.go b/mysql/hex_test.go similarity index 98% rename from mysqldef/hex_test.go rename to mysql/hex_test.go index ba1e3e55b4..551e245b64 100644 --- a/mysqldef/hex_test.go +++ b/mysql/hex_test.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( "strconv" diff --git a/mysqldef/set.go b/mysql/set.go similarity index 99% rename from mysqldef/set.go rename to mysql/set.go index 613a938f55..8c0788a614 100644 --- a/mysqldef/set.go +++ b/mysql/set.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( "strconv" diff --git a/mysqldef/set_test.go b/mysql/set_test.go similarity index 99% rename from mysqldef/set_test.go rename to mysql/set_test.go index 0469da6f23..69887db4b6 100644 --- a/mysqldef/set_test.go +++ b/mysql/set_test.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( . "github.com/pingcap/check" diff --git a/mysqldef/state.go b/mysql/state.go similarity index 99% rename from mysqldef/state.go rename to mysql/state.go index 9dfb4e9c14..67fcd0e5af 100644 --- a/mysqldef/state.go +++ b/mysql/state.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql const ( // DefaultMySQLState is default state of the mySQL diff --git a/mysqldef/time.go b/mysql/time.go similarity index 99% rename from mysqldef/time.go rename to mysql/time.go index e7f909b301..7e094a5625 100644 --- a/mysqldef/time.go +++ b/mysql/time.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( "bytes" diff --git a/mysqldef/time_test.go b/mysql/time_test.go similarity index 99% rename from mysqldef/time_test.go rename to mysql/time_test.go index f2db328247..f58851ace2 100644 --- a/mysqldef/time_test.go +++ b/mysql/time_test.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( "testing" diff --git a/mysqldef/type.go b/mysql/type.go similarity index 99% rename from mysqldef/type.go rename to mysql/type.go index c9230d62b7..36f3441a94 100644 --- a/mysqldef/type.go +++ b/mysql/type.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql // MySQL type informations. const ( diff --git a/mysqldef/type_test.go b/mysql/type_test.go similarity index 98% rename from mysqldef/type_test.go rename to mysql/type_test.go index 47175d7d03..53d6fe646b 100644 --- a/mysqldef/type_test.go +++ b/mysql/type_test.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import . "github.com/pingcap/check" diff --git a/mysqldef/util.go b/mysql/util.go similarity index 98% rename from mysqldef/util.go rename to mysql/util.go index 11f08c03a3..7a5cf72e1f 100644 --- a/mysqldef/util.go +++ b/mysql/util.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql // GetDefaultFieldLength is used for Interger Types, Flen is the display length. // Call this when no Flen assigned in ddl. diff --git a/mysqldef/util_test.go b/mysql/util_test.go similarity index 98% rename from mysqldef/util_test.go rename to mysql/util_test.go index e2e98a9f9d..acda0a42b6 100644 --- a/mysqldef/util_test.go +++ b/mysql/util_test.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import "testing" diff --git a/parser/coldef/col_def.go b/parser/coldef/col_def.go index 8010c08dec..d11609c373 100644 --- a/parser/coldef/col_def.go +++ b/parser/coldef/col_def.go @@ -21,7 +21,7 @@ import ( "github.com/pingcap/tidb/column" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/types" diff --git a/parser/coldef/opt.go b/parser/coldef/opt.go index 8287e674ed..9a31cb0add 100644 --- a/parser/coldef/opt.go +++ b/parser/coldef/opt.go @@ -18,7 +18,7 @@ import ( "strings" "github.com/pingcap/tidb/expression" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) // FloatOpt is used for parsing floating-point type option from SQL. diff --git a/parser/parser.y b/parser/parser.y index bbb81ae742..ef5ed414e8 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -29,7 +29,7 @@ import ( "fmt" "strings" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/parser/coldef" "github.com/pingcap/tidb/ddl" diff --git a/parser/scanner.l b/parser/scanner.l index dcedc40153..8ce33d9011 100644 --- a/parser/scanner.l +++ b/parser/scanner.l @@ -27,7 +27,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/util/stringutil" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) type lexer struct { diff --git a/plan/plans/final_test.go b/plan/plans/final_test.go index 8ba45d898d..d9400ca03d 100644 --- a/plan/plans/final_test.go +++ b/plan/plans/final_test.go @@ -18,7 +18,7 @@ import ( "github.com/pingcap/tidb/column" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/plan/plans" "github.com/pingcap/tidb/rset/rsets" "github.com/pingcap/tidb/util/charset" diff --git a/plan/plans/from_test.go b/plan/plans/from_test.go index 1c31d6cb04..9249ef8dee 100644 --- a/plan/plans/from_test.go +++ b/plan/plans/from_test.go @@ -24,7 +24,7 @@ import ( "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/plan/plans" "github.com/pingcap/tidb/rset/rsets" diff --git a/plan/plans/index_test.go b/plan/plans/index_test.go index 0cc49fbeac..1a713e0e27 100644 --- a/plan/plans/index_test.go +++ b/plan/plans/index_test.go @@ -24,7 +24,7 @@ import ( "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/plan/plans" "github.com/pingcap/tidb/rset/rsets" diff --git a/plan/plans/info.go b/plan/plans/info.go index 9b222c7a6a..6c45fb002c 100644 --- a/plan/plans/info.go +++ b/plan/plans/info.go @@ -25,7 +25,7 @@ import ( "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util/charset" diff --git a/plan/plans/join_test.go b/plan/plans/join_test.go index a14fa38525..c38824a22c 100644 --- a/plan/plans/join_test.go +++ b/plan/plans/join_test.go @@ -20,7 +20,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/plan/plans" "github.com/pingcap/tidb/rset/rsets" diff --git a/plan/plans/plans.go b/plan/plans/plans.go index 0b3c689ad7..393aa98f28 100644 --- a/plan/plans/plans.go +++ b/plan/plans/plans.go @@ -22,7 +22,7 @@ import ( "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/format" diff --git a/plan/plans/show.go b/plan/plans/show.go index 87eff2e91a..62a045412d 100644 --- a/plan/plans/show.go +++ b/plan/plans/show.go @@ -25,7 +25,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" diff --git a/plan/plans/show_test.go b/plan/plans/show_test.go index 7b36ebf810..acde709fe6 100644 --- a/plan/plans/show_test.go +++ b/plan/plans/show_test.go @@ -22,7 +22,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/plan/plans" "github.com/pingcap/tidb/rset/rsets" diff --git a/plan/plans/union_test.go b/plan/plans/union_test.go index 478f7b8f22..70653ae024 100644 --- a/plan/plans/union_test.go +++ b/plan/plans/union_test.go @@ -18,7 +18,7 @@ import ( "github.com/pingcap/tidb/column" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/plan/plans" "github.com/pingcap/tidb/rset/rsets" diff --git a/session.go b/session.go index 31c6a703a7..b7c9aa9add 100644 --- a/session.go +++ b/session.go @@ -31,7 +31,7 @@ import ( "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/kv" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/db" diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index a71b3558b8..5e03948130 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -15,7 +15,7 @@ package variable import ( "github.com/pingcap/tidb/context" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) // SessionVars is to handle user-defined or global variables in current session. diff --git a/stmt/stmts/account_manage.go b/stmt/stmts/account_manage.go index 8c169ef354..cec31dad1e 100644 --- a/stmt/stmts/account_manage.go +++ b/stmt/stmts/account_manage.go @@ -23,7 +23,7 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/context" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/coldef" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/stmt" diff --git a/stmt/stmts/create_test.go b/stmt/stmts/create_test.go index d53e8270d2..15e0cbf1bc 100644 --- a/stmt/stmts/create_test.go +++ b/stmt/stmts/create_test.go @@ -21,7 +21,7 @@ import ( "github.com/ngaut/log" . "github.com/pingcap/check" "github.com/pingcap/tidb" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) func TestT(t *testing.T) { diff --git a/stmt/stmts/grant.go b/stmt/stmts/grant.go index d4dbcbab2a..7ba4152656 100644 --- a/stmt/stmts/grant.go +++ b/stmt/stmts/grant.go @@ -25,7 +25,7 @@ import ( "github.com/pingcap/tidb/column" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/coldef" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx" @@ -221,7 +221,7 @@ func (s *GrantStmt) grantPriv(ctx context.Context, priv *coldef.PrivElem, user * } return s.grantColumnPriv(ctx, priv, user) default: - return errors.Errorf("Unknown grant level: %s", s.Level) + return errors.Errorf("Unknown grant level: %#v", s.Level) } } diff --git a/stmt/stmts/grant_test.go b/stmt/stmts/grant_test.go index 2b4ace611b..eae406ccf9 100644 --- a/stmt/stmts/grant_test.go +++ b/stmt/stmts/grant_test.go @@ -18,7 +18,7 @@ import ( "strings" . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) func (s *testStmtSuite) TestGrantGlobal(c *C) { diff --git a/stmt/stmts/insert.go b/stmt/stmts/insert.go index 374349b975..61251f56f8 100644 --- a/stmt/stmts/insert.go +++ b/stmt/stmts/insert.go @@ -24,7 +24,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/kv" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx/variable" diff --git a/stmt/stmts/stmt_helper.go b/stmt/stmts/stmt_helper.go index d8ebe8682a..cca52a8f5b 100644 --- a/stmt/stmts/stmt_helper.go +++ b/stmt/stmts/stmt_helper.go @@ -18,7 +18,7 @@ import ( "github.com/pingcap/tidb/column" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/table" ) diff --git a/stmt/stmts/transaction.go b/stmt/stmts/transaction.go index b82c398f6b..292eb61936 100644 --- a/stmt/stmts/transaction.go +++ b/stmt/stmts/transaction.go @@ -19,7 +19,7 @@ package stmts import ( "github.com/pingcap/tidb/context" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/stmt" diff --git a/stmt/stmts/update.go b/stmt/stmts/update.go index 4ea5da59dc..1b6f46f4dc 100644 --- a/stmt/stmts/update.go +++ b/stmt/stmts/update.go @@ -26,7 +26,7 @@ import ( "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/plan/plans" "github.com/pingcap/tidb/rset" diff --git a/table/tables/tables.go b/table/tables/tables.go index fdb3f1e2f0..fe6184ee69 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -31,7 +31,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta/autoid" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/util" diff --git a/table/tables/tables_test.go b/table/tables/tables_test.go index 0221fa6147..edaeeb8c2d 100644 --- a/table/tables/tables_test.go +++ b/table/tables/tables_test.go @@ -22,7 +22,7 @@ import ( "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/store/localstore" "github.com/pingcap/tidb/store/localstore/goleveldb" diff --git a/tidb-server/server/conn.go b/tidb-server/server/conn.go index 898f91bd84..3b2f3d8a55 100644 --- a/tidb-server/server/conn.go +++ b/tidb-server/server/conn.go @@ -45,7 +45,7 @@ import ( "github.com/juju/errors" "github.com/ngaut/log" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/arena" "github.com/pingcap/tidb/util/errors2" "github.com/pingcap/tidb/util/hack" diff --git a/tidb-server/server/conn_stmt.go b/tidb-server/server/conn_stmt.go index aacaf4933a..a4d32aa3f3 100644 --- a/tidb-server/server/conn_stmt.go +++ b/tidb-server/server/conn_stmt.go @@ -40,7 +40,7 @@ import ( "strconv" "github.com/juju/errors" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/hack" ) diff --git a/tidb-server/server/driver_tidb.go b/tidb-server/server/driver_tidb.go index 79c23ae93c..1767f61798 100644 --- a/tidb-server/server/driver_tidb.go +++ b/tidb-server/server/driver_tidb.go @@ -18,7 +18,7 @@ import ( "github.com/pingcap/tidb" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/kv" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/util/errors2" "github.com/pingcap/tidb/util/types" diff --git a/tidb-server/server/packetio.go b/tidb-server/server/packetio.go index 225922edc5..150cc7c4d2 100644 --- a/tidb-server/server/packetio.go +++ b/tidb-server/server/packetio.go @@ -40,7 +40,7 @@ import ( "net" "github.com/juju/errors" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) type packetIO struct { diff --git a/tidb-server/server/server.go b/tidb-server/server/server.go index 49681e4da1..e3dd087663 100644 --- a/tidb-server/server/server.go +++ b/tidb-server/server/server.go @@ -37,7 +37,7 @@ import ( "github.com/juju/errors" "github.com/ngaut/log" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/arena" ) diff --git a/tidb-server/server/util.go b/tidb-server/server/util.go index 3dbf420afc..08362bc440 100644 --- a/tidb-server/server/util.go +++ b/tidb-server/server/util.go @@ -42,7 +42,7 @@ import ( "time" "github.com/juju/errors" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/arena" "github.com/pingcap/tidb/util/hack" ) diff --git a/tidb-server/server/util_test.go b/tidb-server/server/util_test.go index 81a14e35de..b5d1cb454b 100644 --- a/tidb-server/server/util_test.go +++ b/tidb-server/server/util_test.go @@ -15,7 +15,7 @@ package server import ( . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) var _ = Suite(&testUtilSuite{}) diff --git a/tidb_test.go b/tidb_test.go index 6fa9a1b2cd..fdd82907da 100644 --- a/tidb_test.go +++ b/tidb_test.go @@ -26,7 +26,7 @@ import ( "github.com/ngaut/log" . "github.com/pingcap/check" "github.com/pingcap/tidb/kv" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/errors2" diff --git a/util/codec/codec.go b/util/codec/codec.go index d679834575..edd951d533 100644 --- a/util/codec/codec.go +++ b/util/codec/codec.go @@ -18,7 +18,7 @@ import ( "time" "github.com/juju/errors" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) var ( diff --git a/util/codec/codec_test.go b/util/codec/codec_test.go index a3a0ccdeff..2334ea46f8 100644 --- a/util/codec/codec_test.go +++ b/util/codec/codec_test.go @@ -19,7 +19,7 @@ import ( "testing" . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) func TestT(t *testing.T) { diff --git a/util/codec/decimal.go b/util/codec/decimal.go index f33b5ab39f..13f39d0850 100644 --- a/util/codec/decimal.go +++ b/util/codec/decimal.go @@ -18,7 +18,7 @@ import ( "math/big" "github.com/juju/errors" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) const ( diff --git a/util/codec/decimal_test.go b/util/codec/decimal_test.go index 28532e766e..f957ee5351 100644 --- a/util/codec/decimal_test.go +++ b/util/codec/decimal_test.go @@ -15,7 +15,7 @@ package codec import ( . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) var _ = Suite(&testDecimalSuite{}) diff --git a/util/types/compare.go b/util/types/compare.go index 3c894bd228..ff035adedb 100644 --- a/util/types/compare.go +++ b/util/types/compare.go @@ -19,7 +19,7 @@ package types import ( "github.com/juju/errors" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) // CompareInt64 returns an integer comparing the int64 x to y. diff --git a/util/types/compare_test.go b/util/types/compare_test.go index 7e492497df..0a17c92037 100644 --- a/util/types/compare_test.go +++ b/util/types/compare_test.go @@ -17,7 +17,7 @@ import ( "time" . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) var _ = Suite(&testCompareSuite{}) diff --git a/util/types/convert.go b/util/types/convert.go index f3c7c4079e..f28721d479 100644 --- a/util/types/convert.go +++ b/util/types/convert.go @@ -25,7 +25,7 @@ import ( "unicode" "github.com/juju/errors" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/charset" ) diff --git a/util/types/convert_test.go b/util/types/convert_test.go index 300d54aea6..1bc10f9d4a 100644 --- a/util/types/convert_test.go +++ b/util/types/convert_test.go @@ -20,7 +20,7 @@ import ( "fmt" . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/charset" ) diff --git a/util/types/etc.go b/util/types/etc.go index 1df8a83eff..fdbb1e7b3f 100644 --- a/util/types/etc.go +++ b/util/types/etc.go @@ -23,7 +23,7 @@ import ( "strings" "github.com/juju/errors" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/errors2" diff --git a/util/types/etc_test.go b/util/types/etc_test.go index 20aa8d6626..5fbf15d113 100644 --- a/util/types/etc_test.go +++ b/util/types/etc_test.go @@ -20,7 +20,7 @@ import ( "time" . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) func TestT(t *testing.T) { diff --git a/util/types/field_type.go b/util/types/field_type.go index ec96912bcf..daf8b354f5 100644 --- a/util/types/field_type.go +++ b/util/types/field_type.go @@ -21,7 +21,7 @@ import ( "fmt" "strings" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/charset" ) diff --git a/util/types/field_type_test.go b/util/types/field_type_test.go index 558c799d8a..a0a069867a 100644 --- a/util/types/field_type_test.go +++ b/util/types/field_type_test.go @@ -15,7 +15,7 @@ package types import ( . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) var _ = Suite(&testFieldTypeSuite{}) From e6b04888e73c0fdb8c4366fc18c89e8f3a94d985 Mon Sep 17 00:00:00 2001 From: dongxu Date: Mon, 19 Oct 2015 16:39:11 +0800 Subject: [PATCH 47/54] mvcc-gc: rename compact policy address comments --- kv/compactor.go | 4 ++-- store/localstore/compactor.go | 6 +++--- store/localstore/kv.go | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/kv/compactor.go b/kv/compactor.go index 47b48d8123..8c3dcadb1b 100644 --- a/kv/compactor.go +++ b/kv/compactor.go @@ -15,8 +15,8 @@ package kv import "time" -// CompactorPolicy defines gc policy of MVCC storage. -type CompactorPolicy struct { +// CompactPolicy defines gc policy of MVCC storage. +type CompactPolicy struct { // SafePoint specifies SafePoint int // TriggerInterval specifies how often should the compactor diff --git a/store/localstore/compactor.go b/store/localstore/compactor.go index 96e983b109..ae29dc207f 100644 --- a/store/localstore/compactor.go +++ b/store/localstore/compactor.go @@ -30,7 +30,7 @@ const ( deleteWorkerCnt = 3 ) -var localCompactorDefaultPolicy = kv.CompactorPolicy{ +var localCompactDefaultPolicy = kv.CompactPolicy{ SafePoint: 20 * 1000, // in ms TriggerInterval: 1 * time.Second, BatchDeleteCnt: 100, @@ -43,7 +43,7 @@ type localstoreCompactor struct { delCh chan kv.EncodedKey ticker *time.Ticker db engine.DB - policy kv.CompactorPolicy + policy kv.CompactPolicy } func (gc *localstoreCompactor) OnSet(k kv.Key) { @@ -182,7 +182,7 @@ func (gc *localstoreCompactor) Stop() { close(gc.stopCh) } -func newLocalCompactor(policy kv.CompactorPolicy, db engine.DB) *localstoreCompactor { +func newLocalCompactor(policy kv.CompactPolicy, db engine.DB) *localstoreCompactor { return &localstoreCompactor{ recentKeys: make(map[string]struct{}), stopCh: make(chan struct{}), diff --git a/store/localstore/kv.go b/store/localstore/kv.go index 9e45a2d873..089874381d 100644 --- a/store/localstore/kv.go +++ b/store/localstore/kv.go @@ -86,7 +86,7 @@ func (d Driver) Open(schema string) (kv.Storage, error) { uuid: uuid.NewV4().String(), path: schema, db: db, - compactor: newLocalCompactor(localCompactorDefaultPolicy, db), + compactor: newLocalCompactor(localCompactDefaultPolicy, db), } mc.cache[schema] = s s.compactor.Start() From 190302b78ed5bbe429717699bca8e648318c74f0 Mon Sep 17 00:00:00 2001 From: dongxu Date: Mon, 19 Oct 2015 16:42:44 +0800 Subject: [PATCH 48/54] mvcc-gc: little refactor address comments --- kv/compactor.go | 6 +++--- store/localstore/compactor.go | 8 +++----- store/localstore/compactor_test.go | 6 +++--- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/kv/compactor.go b/kv/compactor.go index 8c3dcadb1b..47ebd7a619 100644 --- a/kv/compactor.go +++ b/kv/compactor.go @@ -23,7 +23,7 @@ type CompactPolicy struct { // scans outdated data. TriggerInterval time.Duration // BatchDeleteCnt specifies the batch size for - // delete outdated data transaction. + // deleting outdated data transaction. BatchDeleteCnt int } @@ -31,9 +31,9 @@ type CompactPolicy struct { type Compactor interface { // OnGet is the hook point on Txn.Get. OnGet(k Key) - // OnSet is the hook point on Txn.Set + // OnSet is the hook point on Txn.Set. OnSet(k Key) - // OnDelete is the hook point on Txn.Delete + // OnDelete is the hook point on Txn.Delete. OnDelete(k Key) // Compact is the function removes the given key. Compact(ctx interface{}, k Key) error diff --git a/store/localstore/compactor.go b/store/localstore/compactor.go index ae29dc207f..64b875c251 100644 --- a/store/localstore/compactor.go +++ b/store/localstore/compactor.go @@ -84,11 +84,10 @@ func (gc *localstoreCompactor) getAllVersions(k kv.Key) ([]kv.EncodedKey, error) func (gc *localstoreCompactor) deleteWorker() { cnt := 0 batch := gc.db.NewBatch() -L: for { select { case <-gc.stopCh: - break L + return case key := <-gc.delCh: { cnt++ @@ -108,11 +107,11 @@ L: } func (gc *localstoreCompactor) checkExpiredKeysWorker() { -L: for { select { case <-gc.stopCh: - break L + log.Info("GC stopped") + return case <-gc.ticker.C: log.Info("GC trigger") gc.mu.Lock() @@ -128,7 +127,6 @@ L: } } } - log.Info("GC Stopped") } func (gc *localstoreCompactor) filterExpiredKeys(keys []kv.EncodedKey) []kv.EncodedKey { diff --git a/store/localstore/compactor_test.go b/store/localstore/compactor_test.go index a8f3db0424..027437be9c 100644 --- a/store/localstore/compactor_test.go +++ b/store/localstore/compactor_test.go @@ -16,7 +16,6 @@ package localstore import ( "time" - "github.com/ngaut/log" . "github.com/pingcap/check" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/store/localstore/engine" @@ -32,7 +31,6 @@ func count(db engine.DB) int { defer it.Release() totalCnt := 0 for it.Next() { - log.Error(it.Key()) totalCnt++ } return totalCnt @@ -43,7 +41,7 @@ func (s *localstoreCompactorTestSuite) TestCompactor(c *C) { db := store.(*dbStore).db store.(*dbStore).compactor.Stop() - policy := kv.CompactorPolicy{ + policy := kv.CompactPolicy{ SafePoint: 500, BatchDeleteCnt: 1, TriggerInterval: 100 * time.Millisecond, @@ -84,4 +82,6 @@ func (s *localstoreCompactorTestSuite) TestCompactor(c *C) { // Do background GC t = count(db) c.Assert(t, Equals, 3) + + compactor.Stop() } From 71123af77076a52c4cbf2e3b80b2bd80fda60800 Mon Sep 17 00:00:00 2001 From: dongxu Date: Mon, 19 Oct 2015 17:28:16 +0800 Subject: [PATCH 49/54] mvcc-gc: elegant shutdown GC address for comments --- store/localstore/compactor.go | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/store/localstore/compactor.go b/store/localstore/compactor.go index 64b875c251..4de3862300 100644 --- a/store/localstore/compactor.go +++ b/store/localstore/compactor.go @@ -37,13 +37,14 @@ var localCompactDefaultPolicy = kv.CompactPolicy{ } type localstoreCompactor struct { - mu sync.Mutex - recentKeys map[string]struct{} - stopCh chan struct{} - delCh chan kv.EncodedKey - ticker *time.Ticker - db engine.DB - policy kv.CompactPolicy + mu sync.Mutex + recentKeys map[string]struct{} + stopCh chan struct{} + delCh chan kv.EncodedKey + workerWaitGroup sync.WaitGroup + ticker *time.Ticker + db engine.DB + policy kv.CompactPolicy } func (gc *localstoreCompactor) OnSet(k kv.Key) { @@ -82,6 +83,8 @@ func (gc *localstoreCompactor) getAllVersions(k kv.Key) ([]kv.EncodedKey, error) } func (gc *localstoreCompactor) deleteWorker() { + gc.workerWaitGroup.Add(1) + defer gc.workerWaitGroup.Done() cnt := 0 batch := gc.db.NewBatch() for { @@ -107,6 +110,8 @@ func (gc *localstoreCompactor) deleteWorker() { } func (gc *localstoreCompactor) checkExpiredKeysWorker() { + gc.workerWaitGroup.Add(1) + defer gc.workerWaitGroup.Done() for { select { case <-gc.stopCh: @@ -178,6 +183,8 @@ func (gc *localstoreCompactor) Start() { func (gc *localstoreCompactor) Stop() { gc.ticker.Stop() close(gc.stopCh) + // Wait for all workers to finish. + gc.workerWaitGroup.Wait() } func newLocalCompactor(policy kv.CompactPolicy, db engine.DB) *localstoreCompactor { From 1be3cc5d2ec54d009cc41fede441e896002a5c17 Mon Sep 17 00:00:00 2001 From: dongxu Date: Mon, 19 Oct 2015 17:40:29 +0800 Subject: [PATCH 50/54] mvcc-gc: little refactor, rename address comment --- store/localstore/compactor.go | 4 ++-- store/localstore/kv.go | 2 -- store/localstore/local_version_provider.go | 6 +++--- store/localstore/txn.go | 2 -- 4 files changed, 5 insertions(+), 9 deletions(-) diff --git a/store/localstore/compactor.go b/store/localstore/compactor.go index 4de3862300..d0eabb221b 100644 --- a/store/localstore/compactor.go +++ b/store/localstore/compactor.go @@ -145,9 +145,9 @@ func (gc *localstoreCompactor) filterExpiredKeys(keys []kv.EncodedKey) []kv.Enco panic(err) } ts := localVersionToTimestamp(ver) - currentTs := time.Now().UnixNano() / int64(time.Millisecond) + currentTS := time.Now().UnixNano() / int64(time.Millisecond) // Check timeout keys. - if currentTs-int64(ts) >= int64(gc.policy.SafePoint) { + if currentTS-int64(ts) >= int64(gc.policy.SafePoint) { // Skip first version. if first { first = false diff --git a/store/localstore/kv.go b/store/localstore/kv.go index 089874381d..ae96bb5cf0 100644 --- a/store/localstore/kv.go +++ b/store/localstore/kv.go @@ -15,7 +15,6 @@ package localstore import ( "sync" - "time" "github.com/juju/errors" "github.com/ngaut/log" @@ -118,7 +117,6 @@ func (s *dbStore) Begin() (kv.Transaction, error) { return nil, err } txn := &dbTxn{ - startTs: time.Now(), tid: beginVer.Ver, valid: true, store: s, diff --git a/store/localstore/local_version_provider.go b/store/localstore/local_version_provider.go index ad63831341..df235e04b2 100644 --- a/store/localstore/local_version_provider.go +++ b/store/localstore/local_version_provider.go @@ -16,7 +16,7 @@ var ErrOverflow = errors.New("overflow when allocating new version") // LocalVersionProvider uses local timestamp for version. type LocalVersionProvider struct { mu sync.Mutex - lastTimeStampTs uint64 + lastTimeStampTS uint64 n uint64 } @@ -31,14 +31,14 @@ func (l *LocalVersionProvider) CurrentVersion() (kv.Version, error) { var ts uint64 ts = uint64((time.Now().UnixNano() / int64(time.Millisecond)) << timePrecisionOffset) - if l.lastTimeStampTs == uint64(ts) { + if l.lastTimeStampTS == uint64(ts) { l.n++ if l.n >= 1< Date: Mon, 19 Oct 2015 17:45:30 +0800 Subject: [PATCH 51/54] mvcc-gc: rename some variables. address code review comments. --- store/localstore/local_version_provider.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/store/localstore/local_version_provider.go b/store/localstore/local_version_provider.go index df235e04b2..606a8ce382 100644 --- a/store/localstore/local_version_provider.go +++ b/store/localstore/local_version_provider.go @@ -15,9 +15,9 @@ var ErrOverflow = errors.New("overflow when allocating new version") // LocalVersionProvider uses local timestamp for version. type LocalVersionProvider struct { - mu sync.Mutex - lastTimeStampTS uint64 - n uint64 + mu sync.Mutex + lastTimeStamp uint64 + n uint64 } const ( @@ -31,14 +31,14 @@ func (l *LocalVersionProvider) CurrentVersion() (kv.Version, error) { var ts uint64 ts = uint64((time.Now().UnixNano() / int64(time.Millisecond)) << timePrecisionOffset) - if l.lastTimeStampTS == uint64(ts) { + if l.lastTimeStamp == uint64(ts) { l.n++ if l.n >= 1< Date: Mon, 19 Oct 2015 18:09:55 +0800 Subject: [PATCH 52/54] mvcc-gc: add some comments address review comments. --- store/localstore/local_version_provider.go | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/store/localstore/local_version_provider.go b/store/localstore/local_version_provider.go index 606a8ce382..997bf786b6 100644 --- a/store/localstore/local_version_provider.go +++ b/store/localstore/local_version_provider.go @@ -16,8 +16,10 @@ var ErrOverflow = errors.New("overflow when allocating new version") // LocalVersionProvider uses local timestamp for version. type LocalVersionProvider struct { mu sync.Mutex - lastTimeStamp uint64 - n uint64 + lastTimestamp uint64 + // logical guaranteed version's monotonic increasing for calls when lastTimestamp + // are equal. + logical uint64 } const ( @@ -31,15 +33,15 @@ func (l *LocalVersionProvider) CurrentVersion() (kv.Version, error) { var ts uint64 ts = uint64((time.Now().UnixNano() / int64(time.Millisecond)) << timePrecisionOffset) - if l.lastTimeStamp == uint64(ts) { - l.n++ - if l.n >= 1<= 1< Date: Mon, 19 Oct 2015 20:02:10 +0800 Subject: [PATCH 53/54] kv: remove useless keyCmpFn from Seek/Next iterface. --- ddl/ddl.go | 4 ++-- kv/index_iter.go | 10 +++++----- kv/iter.go | 2 +- kv/kv.go | 4 ++-- kv/union_iter.go | 4 ++-- plan/plans/from.go | 4 ++-- store/localstore/kv_test.go | 32 ++++++++++++++++---------------- store/localstore/mvcc_test.go | 18 +++++++++--------- store/localstore/snapshot.go | 4 ++-- store/localstore/txn.go | 8 +------- table/tables/tables.go | 2 +- util/prefix_helper.go | 8 ++++---- 12 files changed, 47 insertions(+), 53 deletions(-) diff --git a/ddl/ddl.go b/ddl/ddl.go index 2525a9327b..daab43b48c 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -512,7 +512,7 @@ func updateOldRows(ctx context.Context, t *tables.Table, col *column.Col) error if err != nil { return errors.Trace(err) } - it, err := txn.Seek([]byte(t.FirstKey()), nil) + it, err := txn.Seek([]byte(t.FirstKey())) if err != nil { return errors.Trace(err) } @@ -685,7 +685,7 @@ func (d *ddl) buildIndex(ctx context.Context, t table.Table, idxInfo *model.Inde if err != nil { return errors.Trace(err) } - it, err := txn.Seek([]byte(firstKey), nil) + it, err := txn.Seek([]byte(firstKey)) if err != nil { return errors.Trace(err) } diff --git a/kv/index_iter.go b/kv/index_iter.go index 98c60557ec..000c9f13ea 100644 --- a/kv/index_iter.go +++ b/kv/index_iter.go @@ -87,7 +87,7 @@ func (c *indexIter) Next() (k []interface{}, h int64, err error) { k = vv } // update new iter to next - newIt, err := c.it.Next(hasPrefix([]byte(c.prefix))) + newIt, err := c.it.Next() if err != nil { return nil, 0, errors.Trace(err) } @@ -198,7 +198,7 @@ func hasPrefix(prefix []byte) FnKeyCmp { // Drop removes the KV index from store. func (c *kvIndex) Drop(txn Transaction) error { prefix := []byte(c.prefix) - it, err := txn.Seek(Key(prefix), hasPrefix(prefix)) + it, err := txn.Seek(Key(prefix)) if err != nil { return errors.Trace(err) } @@ -213,7 +213,7 @@ func (c *kvIndex) Drop(txn Transaction) error { if err != nil { return errors.Trace(err) } - it, err = it.Next(hasPrefix(prefix)) + it, err = it.Next() if err != nil { return errors.Trace(err) } @@ -227,7 +227,7 @@ func (c *kvIndex) Seek(txn Transaction, indexedValues []interface{}) (iter Index if err != nil { return nil, false, errors.Trace(err) } - it, err := txn.Seek(keyBuf, hasPrefix([]byte(c.prefix))) + it, err := txn.Seek(keyBuf) if err != nil { return nil, false, errors.Trace(err) } @@ -242,7 +242,7 @@ func (c *kvIndex) Seek(txn Transaction, indexedValues []interface{}) (iter Index // SeekFirst returns an iterator which points to the first entry of the KV index. func (c *kvIndex) SeekFirst(txn Transaction) (iter IndexIterator, err error) { prefix := []byte(c.prefix) - it, err := txn.Seek(prefix, hasPrefix(prefix)) + it, err := txn.Seek(prefix) if err != nil { return nil, errors.Trace(err) } diff --git a/kv/iter.go b/kv/iter.go index 7dc7beaecf..03a99c5a82 100644 --- a/kv/iter.go +++ b/kv/iter.go @@ -72,7 +72,7 @@ func DecodeValue(data []byte) ([]interface{}, error) { func NextUntil(it Iterator, fn FnKeyCmp) (Iterator, error) { var err error for it.Valid() && !fn([]byte(it.Key())) { - it, err = it.Next(nil) + it, err = it.Next() if err != nil { return nil, errors.Trace(err) } diff --git a/kv/kv.go b/kv/kv.go index b93e9dc3a2..6681cc6f0f 100644 --- a/kv/kv.go +++ b/kv/kv.go @@ -104,7 +104,7 @@ type Transaction interface { // Set sets the value for key k as v into KV store. Set(k Key, v []byte) error // Seek searches for the entry with key k in KV store. - Seek(k Key, fnKeyCmp func(key Key) bool) (Iterator, error) + Seek(k Key) (Iterator, error) // Inc increases the value for key k in KV store by step. Inc(k Key, step int64) (int64, error) // GetInt64 get int64 which created by Inc method. @@ -171,7 +171,7 @@ type FnKeyCmp func(key Key) bool // Iterator is the interface for a interator on KV store. type Iterator interface { - Next(FnKeyCmp) (Iterator, error) + Next() (Iterator, error) Value() []byte Key() string Valid() bool diff --git a/kv/union_iter.go b/kv/union_iter.go index 6b677c433c..a25f81e776 100644 --- a/kv/union_iter.go +++ b/kv/union_iter.go @@ -51,7 +51,7 @@ func (iter *UnionIter) dirtyNext() { // Go next and update valid status. func (iter *UnionIter) snapshotNext() { - iter.snapshotIt, _ = iter.snapshotIt.Next(nil) + iter.snapshotIt, _ = iter.snapshotIt.Next() iter.snapshotValid = iter.snapshotIt.Valid() } @@ -116,7 +116,7 @@ func (iter *UnionIter) updateCur() { } // Next implements the Iterator Next interface. -func (iter *UnionIter) Next(f FnKeyCmp) (Iterator, error) { +func (iter *UnionIter) Next() (Iterator, error) { if !iter.curIsDirty { iter.snapshotNext() } else { diff --git a/plan/plans/from.go b/plan/plans/from.go index a1f844fd86..a0c1450404 100644 --- a/plan/plans/from.go +++ b/plan/plans/from.go @@ -68,7 +68,7 @@ func (r *TableNilPlan) Next(ctx context.Context) (row *plan.Row, err error) { if err != nil { return nil, errors.Trace(err) } - r.iter, err = txn.Seek([]byte(r.T.FirstKey()), nil) + r.iter, err = txn.Seek([]byte(r.T.FirstKey())) if err != nil { return nil, errors.Trace(err) } @@ -274,7 +274,7 @@ func (r *TableDefaultPlan) Next(ctx context.Context) (row *plan.Row, err error) if err != nil { return nil, errors.Trace(err) } - r.iter, err = txn.Seek([]byte(r.T.FirstKey()), nil) + r.iter, err = txn.Seek([]byte(r.T.FirstKey())) if err != nil { return nil, errors.Trace(err) } diff --git a/store/localstore/kv_test.go b/store/localstore/kv_test.go index d09c35043d..5b28af06f1 100644 --- a/store/localstore/kv_test.go +++ b/store/localstore/kv_test.go @@ -93,7 +93,7 @@ func valToStr(c *C, iter kv.Iterator) string { func checkSeek(c *C, txn kv.Transaction) { for i := startIndex; i < testCount; i++ { val := encodeInt(i) - iter, err := txn.Seek(val, nil) + iter, err := txn.Seek(val) c.Assert(err, IsNil) c.Assert(iter.Key(), Equals, string(val)) c.Assert(decodeInt([]byte(valToStr(c, iter))), Equals, i) @@ -103,12 +103,12 @@ func checkSeek(c *C, txn kv.Transaction) { // Test iterator Next() for i := startIndex; i < testCount-1; i++ { val := encodeInt(i) - iter, err := txn.Seek(val, nil) + iter, err := txn.Seek(val) c.Assert(err, IsNil) c.Assert(iter.Key(), Equals, string(val)) c.Assert(valToStr(c, iter), Equals, string(val)) - next, err := iter.Next(nil) + next, err := iter.Next() c.Assert(err, IsNil) c.Assert(next.Valid(), IsTrue) @@ -119,7 +119,7 @@ func checkSeek(c *C, txn kv.Transaction) { } // Non exist seek test - iter, err := txn.Seek(encodeInt(testCount), nil) + iter, err := txn.Seek(encodeInt(testCount)) c.Assert(err, IsNil) c.Assert(iter.Valid(), IsFalse) iter.Close() @@ -264,19 +264,19 @@ func (s *testKVSuite) TestDelete2(c *C) { txn, err = s.s.Begin() c.Assert(err, IsNil) - it, err := txn.Seek([]byte("DATA_test_tbl_department_record__0000000001_0003"), nil) + it, err := txn.Seek([]byte("DATA_test_tbl_department_record__0000000001_0003")) c.Assert(err, IsNil) for it.Valid() { err = txn.Delete([]byte(it.Key())) c.Assert(err, IsNil) - it, err = it.Next(nil) + it, err = it.Next() c.Assert(err, IsNil) } txn.Commit() txn, err = s.s.Begin() c.Assert(err, IsNil) - it, _ = txn.Seek([]byte("DATA_test_tbl_department_record__000000000"), nil) + it, _ = txn.Seek([]byte("DATA_test_tbl_department_record__000000000")) c.Assert(it.Valid(), IsFalse) txn.Commit() @@ -299,7 +299,7 @@ func (s *testKVSuite) TestBasicSeek(c *C) { c.Assert(err, IsNil) defer txn.Commit() - it, err := txn.Seek([]byte("2"), nil) + it, err := txn.Seek([]byte("2")) c.Assert(err, IsNil) c.Assert(it.Valid(), Equals, false) txn.Delete([]byte("1")) @@ -320,30 +320,30 @@ func (s *testKVSuite) TestBasicTable(c *C) { err = txn.Set([]byte("1"), []byte("1")) c.Assert(err, IsNil) - it, err := txn.Seek([]byte("0"), nil) + it, err := txn.Seek([]byte("0")) c.Assert(err, IsNil) c.Assert(it.Key(), Equals, "1") err = txn.Set([]byte("0"), []byte("0")) c.Assert(err, IsNil) - it, err = txn.Seek([]byte("0"), nil) + it, err = txn.Seek([]byte("0")) c.Assert(err, IsNil) c.Assert(it.Key(), Equals, "0") err = txn.Delete([]byte("0")) c.Assert(err, IsNil) txn.Delete([]byte("1")) - it, err = txn.Seek([]byte("0"), nil) + it, err = txn.Seek([]byte("0")) c.Assert(err, IsNil) c.Assert(it.Key(), Equals, "2") err = txn.Delete([]byte("3")) c.Assert(err, IsNil) - it, err = txn.Seek([]byte("2"), nil) + it, err = txn.Seek([]byte("2")) c.Assert(err, IsNil) c.Assert(it.Key(), Equals, "2") - it, err = txn.Seek([]byte("3"), nil) + it, err = txn.Seek([]byte("3")) c.Assert(err, IsNil) c.Assert(it.Key(), Equals, "4") err = txn.Delete([]byte("2")) @@ -401,13 +401,13 @@ func (s *testKVSuite) TestSeekMin(c *C) { txn.Set([]byte(kv.key), []byte(kv.value)) } - it, err := txn.Seek(nil, nil) + it, err := txn.Seek(nil) for it.Valid() { fmt.Printf("%s, %s\n", it.Key(), it.Value()) - it, _ = it.Next(nil) + it, _ = it.Next() } - it, err = txn.Seek([]byte("DATA_test_main_db_tbl_tbl_test_record__00000000000000000000"), nil) + it, err = txn.Seek([]byte("DATA_test_main_db_tbl_tbl_test_record__00000000000000000000")) c.Assert(err, IsNil) c.Assert(string(it.Key()), Equals, "DATA_test_main_db_tbl_tbl_test_record__00000000000000000001") diff --git a/store/localstore/mvcc_test.go b/store/localstore/mvcc_test.go index db187e1039..9de9682e19 100644 --- a/store/localstore/mvcc_test.go +++ b/store/localstore/mvcc_test.go @@ -138,11 +138,11 @@ func (t *testMvccSuite) TestMvccPutAndDel(c *C) { func (t *testMvccSuite) TestMvccNext(c *C) { txn, _ := t.s.Begin() - it, err := txn.Seek(encodeInt(2), nil) + it, err := txn.Seek(encodeInt(2)) c.Assert(err, IsNil) c.Assert(it.Valid(), IsTrue) for it.Valid() { - it, err = it.Next(nil) + it, err = it.Next() c.Assert(err, IsNil) } txn.Commit() @@ -199,7 +199,7 @@ func (t *testMvccSuite) TestMvccSuiteGetLatest(c *C) { c.Assert(err, IsNil) c.Assert(string(b), Equals, string(encodeInt(100+9))) // we can always scan newest data - it, err := tx.Seek(encodeInt(5), nil) + it, err := tx.Seek(encodeInt(5)) c.Assert(err, IsNil) c.Assert(it.Valid(), IsTrue) c.Assert(string(it.Value()), Equals, string(encodeInt(100+9))) @@ -244,7 +244,7 @@ func (t *testMvccSuite) TestMvccSnapshotScan(c *C) { if string(it.Value()) == "new" { found = true } - it, err = it.Next(nil) + it, err = it.Next() c.Assert(err, IsNil) } return found @@ -271,11 +271,11 @@ func (t *testMvccSuite) TestBufferedIterator(c *C) { tx.Commit() tx, _ = s.Begin() - iter, err := tx.Seek([]byte{0}, nil) + iter, err := tx.Seek([]byte{0}) c.Assert(err, IsNil) cnt := 0 for iter.Valid() { - iter, err = iter.Next(nil) + iter, err = iter.Next() c.Assert(err, IsNil) cnt++ } @@ -283,7 +283,7 @@ func (t *testMvccSuite) TestBufferedIterator(c *C) { c.Assert(cnt, Equals, 6) tx, _ = s.Begin() - it, err := tx.Seek([]byte{0xff, 0xee}, nil) + it, err := tx.Seek([]byte{0xff, 0xee}) c.Assert(err, IsNil) c.Assert(it.Valid(), IsTrue) c.Assert(it.Key(), Equals, "\xff\xff\xee\xff") @@ -291,11 +291,11 @@ func (t *testMvccSuite) TestBufferedIterator(c *C) { // no such key tx, _ = s.Begin() - it, err = tx.Seek([]byte{0xff, 0xff, 0xff, 0xff}, nil) + it, err = tx.Seek([]byte{0xff, 0xff, 0xff, 0xff}) c.Assert(err, IsNil) c.Assert(it.Valid(), IsFalse) - it, err = tx.Seek([]byte{0x0, 0xff}, nil) + it, err = tx.Seek([]byte{0x0, 0xff}) c.Assert(err, IsNil) c.Assert(it.Valid(), IsTrue) c.Assert(it.Value(), DeepEquals, []byte("2")) diff --git a/store/localstore/snapshot.go b/store/localstore/snapshot.go index ff53e90511..2340001520 100644 --- a/store/localstore/snapshot.go +++ b/store/localstore/snapshot.go @@ -136,11 +136,11 @@ func newDBIter(s *dbSnapshot, startKey kv.Key, exceptedVer kv.Version) *dbIter { valid: true, exceptedVersion: exceptedVer, } - it.Next(nil) + it.Next() return it } -func (it *dbIter) Next(fn kv.FnKeyCmp) (kv.Iterator, error) { +func (it *dbIter) Next() (kv.Iterator, error) { encKey := codec.EncodeBytes(nil, it.startKey) var retErr error var engineIter engine.Iterator diff --git a/store/localstore/txn.go b/store/localstore/txn.go index 3dd6524ade..9725501685 100644 --- a/store/localstore/txn.go +++ b/store/localstore/txn.go @@ -148,7 +148,7 @@ func (txn *dbTxn) Set(k kv.Key, data []byte) error { return nil } -func (txn *dbTxn) Seek(k kv.Key, fnKeyCmp func(kv.Key) bool) (kv.Iterator, error) { +func (txn *dbTxn) Seek(k kv.Key) (kv.Iterator, error) { log.Debugf("seek key:%q, txn:%d", k, txn.tid) k = kv.EncodeKey(k) @@ -160,12 +160,6 @@ func (txn *dbTxn) Seek(k kv.Key, fnKeyCmp func(kv.Key) bool) (kv.Iterator, error return &kv.UnionIter{}, nil } - if fnKeyCmp != nil { - if fnKeyCmp([]byte(iter.Key())[:1]) { - return &kv.UnionIter{}, nil - } - } - return iter, nil } diff --git a/table/tables/tables.go b/table/tables/tables.go index fe6184ee69..e8d90d46b5 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -530,7 +530,7 @@ func (t *Table) IterRecords(ctx context.Context, startKey string, cols []*column return err } - it, err := txn.Seek([]byte(startKey), nil) + it, err := txn.Seek([]byte(startKey)) if err != nil { return err } diff --git a/util/prefix_helper.go b/util/prefix_helper.go index 12aa114962..b8f40559c2 100644 --- a/util/prefix_helper.go +++ b/util/prefix_helper.go @@ -36,7 +36,7 @@ func hasPrefix(prefix []byte) kv.FnKeyCmp { // ScanMetaWithPrefix scans metadata with the prefix. func ScanMetaWithPrefix(txn kv.Transaction, prefix string, filter func([]byte, []byte) bool) error { - iter, err := txn.Seek([]byte(prefix), hasPrefix([]byte(prefix))) + iter, err := txn.Seek([]byte(prefix)) if err != nil { return errors.Trace(err) } @@ -51,7 +51,7 @@ func ScanMetaWithPrefix(txn kv.Transaction, prefix string, filter func([]byte, [ if !filter([]byte(iter.Key()), iter.Value()) { break } - iter, err = iter.Next(hasPrefix([]byte(prefix))) + iter, err = iter.Next() } else { break } @@ -68,7 +68,7 @@ func DelKeyWithPrefix(ctx context.Context, prefix string) error { } var keys []string - iter, err := txn.Seek([]byte(prefix), hasPrefix([]byte(prefix))) + iter, err := txn.Seek([]byte(prefix)) if err != nil { return errors.Trace(err) } @@ -81,7 +81,7 @@ func DelKeyWithPrefix(ctx context.Context, prefix string) error { if iter.Valid() && strings.HasPrefix(iter.Key(), prefix) { keys = append(keys, iter.Key()) - iter, err = iter.Next(hasPrefix([]byte(prefix))) + iter, err = iter.Next() } else { break } From 2aa4cad1d292b659ceeee4a8f795c98e38f28e52 Mon Sep 17 00:00:00 2001 From: ngaut Date: Mon, 19 Oct 2015 20:10:28 +0800 Subject: [PATCH 54/54] *: Tiny clean up --- kv/index_iter.go | 6 ------ util/prefix_helper.go | 6 ------ util/prefix_helper_test.go | 5 +---- 3 files changed, 1 insertion(+), 16 deletions(-) diff --git a/kv/index_iter.go b/kv/index_iter.go index 000c9f13ea..e19004ddeb 100644 --- a/kv/index_iter.go +++ b/kv/index_iter.go @@ -189,12 +189,6 @@ func (c *kvIndex) Delete(txn Transaction, indexedValues []interface{}, h int64) return errors.Trace(err) } -func hasPrefix(prefix []byte) FnKeyCmp { - return func(k Key) bool { - return bytes.HasPrefix([]byte(k), prefix) - } -} - // Drop removes the KV index from store. func (c *kvIndex) Drop(txn Transaction) error { prefix := []byte(c.prefix) diff --git a/util/prefix_helper.go b/util/prefix_helper.go index b8f40559c2..78294a1a40 100644 --- a/util/prefix_helper.go +++ b/util/prefix_helper.go @@ -28,12 +28,6 @@ import ( "github.com/pingcap/tidb/util/codec" ) -func hasPrefix(prefix []byte) kv.FnKeyCmp { - return func(k kv.Key) bool { - return bytes.HasPrefix(k, prefix) - } -} - // ScanMetaWithPrefix scans metadata with the prefix. func ScanMetaWithPrefix(txn kv.Transaction, prefix string, filter func([]byte, []byte) bool) error { iter, err := txn.Seek([]byte(prefix)) diff --git a/util/prefix_helper_test.go b/util/prefix_helper_test.go index bdb9340614..18e00e1366 100644 --- a/util/prefix_helper_test.go +++ b/util/prefix_helper_test.go @@ -161,11 +161,8 @@ func (s *testPrefixSuite) TestCode(c *C) { b1 := EncodeRecordKey("aa", 1, 0) b2 := EncodeRecordKey("aa", 1, 1) c.Logf("%#v, %#v", b2, b1) - raw, err := codec.StripEnd(b1) + _, err := codec.StripEnd(b1) c.Assert(err, IsNil) - f := hasPrefix(raw) - has := f(b2) - c.Assert(has, IsTrue) } func (s *testPrefixSuite) TestPrefixFilter(c *C) {