*: merge master

This commit is contained in:
siddontang
2015-10-20 11:06:22 +08:00
143 changed files with 2105 additions and 714 deletions

View File

@ -22,7 +22,7 @@ import (
"runtime/debug"
"github.com/ngaut/log"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/util/errors"
"github.com/pingcap/tidb/util/errors2"
)

View File

@ -24,7 +24,7 @@ import (
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/util/types"
)

View File

@ -18,7 +18,7 @@ import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/util/types"
)

View File

@ -30,7 +30,7 @@ import (
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/meta"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/coldef"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/table/tables"
@ -512,7 +512,7 @@ func updateOldRows(ctx context.Context, t *tables.Table, col *column.Col) error
if err != nil {
return errors.Trace(err)
}
it, err := txn.Seek([]byte(t.FirstKey()), nil)
it, err := txn.Seek([]byte(t.FirstKey()))
if err != nil {
return errors.Trace(err)
}
@ -685,7 +685,7 @@ func (d *ddl) buildIndex(ctx context.Context, t table.Table, idxInfo *model.Inde
if err != nil {
return errors.Trace(err)
}
it, err := txn.Seek([]byte(firstKey), nil)
it, err := txn.Seek([]byte(firstKey))
if err != nil {
return errors.Trace(err)
}

View File

@ -23,7 +23,7 @@ import (
"github.com/pingcap/tidb/ddl"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/stmt"

View File

@ -31,7 +31,7 @@ import (
"github.com/juju/errors"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/rset"
"github.com/pingcap/tidb/sessionctx"
qerror "github.com/pingcap/tidb/util/errors"

View File

@ -23,7 +23,7 @@ import (
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
)

View File

@ -20,7 +20,7 @@ import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/expression/builtin"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
)
@ -186,7 +186,7 @@ func (s *testBinOpSuite) TestIdentRelOp(c *C) {
f := func(name string) *Ident {
return &Ident{
model.NewCIStr(name),
CIStr: model.NewCIStr(name),
}
}

View File

@ -24,7 +24,7 @@ import (
"github.com/juju/errors"
"github.com/pingcap/tidb/kv/memkv"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/util/types"
)

View File

@ -15,7 +15,7 @@ package builtin
import (
. "github.com/pingcap/check"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/util/types"
)

View File

@ -19,7 +19,7 @@ import (
"time"
. "github.com/pingcap/check"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
)
func (s *testBuiltinSuite) TestLength(c *C) {

View File

@ -21,7 +21,7 @@ import (
"time"
"github.com/juju/errors"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/util/types"
)

View File

@ -18,7 +18,7 @@ import (
"time"
. "github.com/pingcap/check"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
)
func (s *testBuiltinSuite) TestDate(c *C) {

View File

@ -18,7 +18,7 @@ import (
"github.com/pingcap/tidb/context"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/util/types"
)

View File

@ -17,7 +17,7 @@ import (
"errors"
. "github.com/pingcap/check"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/util/charset"
"github.com/pingcap/tidb/util/types"
)

104
expression/date_add.go Normal file
View File

@ -0,0 +1,104 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package expression
import (
"fmt"
"strings"
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/util/types"
)
// DateAdd is for time date_add function.
// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add
type DateAdd struct {
Unit string
Date Expression
Interval Expression
}
// Clone implements the Expression Clone interface.
func (da *DateAdd) Clone() Expression {
n := *da
return &n
}
// Eval implements the Expression Eval interface.
func (da *DateAdd) Eval(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) {
dv, err := da.Date.Eval(ctx, args)
if dv == nil || err != nil {
return nil, errors.Trace(err)
}
sv, err := types.ToString(dv)
if err != nil {
return nil, errors.Trace(err)
}
f := types.NewFieldType(mysql.TypeDatetime)
f.Decimal = mysql.MaxFsp
dv, err = types.Convert(sv, f)
if dv == nil || err != nil {
return nil, errors.Trace(err)
}
t, ok := dv.(mysql.Time)
if !ok {
return nil, errors.Errorf("need time type, but got %T", dv)
}
iv, err := da.Interval.Eval(ctx, args)
if iv == nil || err != nil {
return nil, errors.Trace(err)
}
format, err := types.ToString(iv)
if err != nil {
return nil, errors.Trace(err)
}
years, months, days, durations, err := mysql.ExtractTimeValue(da.Unit, strings.TrimSpace(format))
if err != nil {
return nil, errors.Trace(err)
}
t.Time = t.Time.Add(durations)
t.Time = t.Time.AddDate(int(years), int(months), int(days))
// "2011-11-11 10:10:20.000000" outputs "2011-11-11 10:10:20".
if t.Time.Nanosecond() == 0 {
t.Fsp = 0
}
return t, nil
}
// IsStatic implements the Expression IsStatic interface.
func (da *DateAdd) IsStatic() bool {
return da.Date.IsStatic() && da.Interval.IsStatic()
}
// String implements the Expression String interface.
func (da *DateAdd) String() string {
return fmt.Sprintf("DATE_ADD(%s, INTERVAL %s %s)", da.Date, da.Interval, strings.ToUpper(da.Unit))
}
// Accept implements the Visitor Accept interface.
func (da *DateAdd) Accept(v Visitor) (Expression, error) {
return v.VisitDateAdd(da)
}

144
expression/date_add_test.go Normal file
View File

@ -0,0 +1,144 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package expression
import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/mysql"
)
var _ = Suite(&testDateAddSuite{})
type testDateAddSuite struct {
}
func (t *testDateAddSuite) TestDateAdd(c *C) {
input := "2011-11-11 10:10:10"
e := &DateAdd{
Unit: "DAY",
Date: Value{Val: input},
Interval: Value{Val: "1"},
}
c.Assert(e.String(), Equals, `DATE_ADD("2011-11-11 10:10:10", INTERVAL "1" DAY)`)
c.Assert(e.Clone(), NotNil)
c.Assert(e.IsStatic(), IsTrue)
_, err := e.Eval(nil, nil)
c.Assert(err, IsNil)
// Test null.
e = &DateAdd{
Unit: "DAY",
Date: Value{Val: nil},
Interval: Value{Val: "1"},
}
v, err := e.Eval(nil, nil)
c.Assert(err, IsNil)
c.Assert(v, IsNil)
e = &DateAdd{
Unit: "DAY",
Date: Value{Val: input},
Interval: Value{Val: nil},
}
v, err = e.Eval(nil, nil)
c.Assert(err, IsNil)
c.Assert(v, IsNil)
// Test eval.
tbl := []struct {
Unit string
Interval interface{}
Expect string
}{
{"MICROSECOND", "1000", "2011-11-11 10:10:10.001000"},
{"MICROSECOND", 1000, "2011-11-11 10:10:10.001000"},
{"SECOND", "10", "2011-11-11 10:10:20"},
{"MINUTE", "10", "2011-11-11 10:20:10"},
{"HOUR", "10", "2011-11-11 20:10:10"},
{"DAY", "11", "2011-11-22 10:10:10"},
{"WEEK", "2", "2011-11-25 10:10:10"},
{"MONTH", "2", "2012-01-11 10:10:10"},
{"QUARTER", "4", "2012-11-11 10:10:10"},
{"YEAR", "2", "2013-11-11 10:10:10"},
{"SECOND_MICROSECOND", "10.00100000", "2011-11-11 10:10:20.100000"},
{"SECOND_MICROSECOND", "10.0010000000", "2011-11-11 10:10:30"},
{"SECOND_MICROSECOND", "10.0010000010", "2011-11-11 10:10:30.000010"},
{"MINUTE_MICROSECOND", "10:10.100", "2011-11-11 10:20:20.100000"},
{"MINUTE_SECOND", "10:10", "2011-11-11 10:20:20"},
{"HOUR_MICROSECOND", "10:10:10.100", "2011-11-11 20:20:20.100000"},
{"HOUR_SECOND", "10:10:10", "2011-11-11 20:20:20"},
{"HOUR_MINUTE", "10:10", "2011-11-11 20:20:10"},
{"DAY_MICROSECOND", "11 10:10:10.100", "2011-11-22 20:20:20.100000"},
{"DAY_SECOND", "11 10:10:10", "2011-11-22 20:20:20"},
{"DAY_MINUTE", "11 10:10", "2011-11-22 20:20:10"},
{"DAY_HOUR", "11 10", "2011-11-22 20:10:10"},
{"YEAR_MONTH", "11-1", "2022-12-11 10:10:10"},
{"YEAR_MONTH", "11-11", "2023-10-11 10:10:10"},
}
for _, t := range tbl {
e := &DateAdd{
Unit: t.Unit,
Date: Value{Val: input},
Interval: Value{Val: t.Interval},
}
v, err := e.Eval(nil, nil)
c.Assert(err, IsNil)
value, ok := v.(mysql.Time)
c.Assert(ok, IsTrue)
c.Assert(value.String(), Equals, t.Expect)
}
// Test error.
errInput := "20111111 10:10:10"
errTbl := []struct {
Unit string
Interval interface{}
}{
{"MICROSECOND", "abc1000"},
{"MICROSECOND", ""},
{"SECOND_MICROSECOND", "10"},
{"MINUTE_MICROSECOND", "10.0000"},
{"MINUTE_MICROSECOND", "10:10:10.0000"},
// MySQL support, but tidb not.
{"HOUR_MICROSECOND", "10:10.0000"},
{"YEAR_MONTH", "10 1"},
}
for _, t := range errTbl {
e := &DateAdd{
Unit: t.Unit,
Date: Value{Val: input},
Interval: Value{Val: t.Interval},
}
_, err := e.Eval(nil, nil)
c.Assert(err, NotNil)
e = &DateAdd{
Unit: t.Unit,
Date: Value{Val: errInput},
Interval: Value{Val: t.Interval},
}
v, err := e.Eval(nil, nil)
c.Assert(err, NotNil, Commentf("%s", v))
}
}

View File

@ -19,12 +19,12 @@ import (
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/util/types"
)
// Extract is for time extract function.
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_extract
// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_extract
type Extract struct {
Unit string
Date Expression
@ -56,7 +56,7 @@ func (e *Extract) Eval(ctx context.Context, args map[interface{}]interface{}) (i
return nil, errors.Errorf("need time type, but got %T", v)
}
n, err1 := extractTime(e.Unit, t)
n, err1 := mysql.ExtractTimeNum(e.Unit, t)
if err1 != nil {
return nil, errors.Trace(err1)
}
@ -78,70 +78,3 @@ func (e *Extract) String() string {
func (e *Extract) Accept(v Visitor) (Expression, error) {
return v.VisitExtract(e)
}
func extractTime(unit string, t mysql.Time) (int64, error) {
switch strings.ToUpper(unit) {
case "MICROSECOND":
return int64(t.Nanosecond() / 1000), nil
case "SECOND":
return int64(t.Second()), nil
case "MINUTE":
return int64(t.Minute()), nil
case "HOUR":
return int64(t.Hour()), nil
case "DAY":
return int64(t.Day()), nil
case "WEEK":
_, week := t.ISOWeek()
return int64(week), nil
case "MONTH":
return int64(t.Month()), nil
case "QUARTER":
m := int64(t.Month())
// 1 - 3 -> 1
// 4 - 6 -> 2
// 7 - 9 -> 3
// 10 - 12 -> 4
return (m + 2) / 3, nil
case "YEAR":
return int64(t.Year()), nil
case "SECOND_MICROSECOND":
return int64(t.Second())*1000000 + int64(t.Nanosecond())/1000, nil
case "MINUTE_MICROSECOND":
_, m, s := t.Clock()
return int64(m)*100000000 + int64(s)*1000000 + int64(t.Nanosecond())/1000, nil
case "MINUTE_SECOND":
_, m, s := t.Clock()
return int64(m*100 + s), nil
case "HOUR_MICROSECOND":
h, m, s := t.Clock()
return int64(h)*10000000000 + int64(m)*100000000 + int64(s)*1000000 + int64(t.Nanosecond())/1000, nil
case "HOUR_SECOND":
h, m, s := t.Clock()
return int64(h)*10000 + int64(m)*100 + int64(s), nil
case "HOUR_MINUTE":
h, m, _ := t.Clock()
return int64(h)*100 + int64(m), nil
case "DAY_MICROSECOND":
h, m, s := t.Clock()
d := t.Day()
return int64(d*1000000+h*10000+m*100+s)*1000000 + int64(t.Nanosecond())/1000, nil
case "DAY_SECOND":
h, m, s := t.Clock()
d := t.Day()
return int64(d)*1000000 + int64(h)*10000 + int64(m)*100 + int64(s), nil
case "DAY_MINUTE":
h, m, _ := t.Clock()
d := t.Day()
return int64(d)*10000 + int64(h)*100 + int64(m), nil
case "DAY_HOUR":
h, _, _ := t.Clock()
d := t.Day()
return int64(d)*100 + int64(h), nil
case "YEAR_MONTH":
y, m, _ := t.Date()
return int64(y)*100 + int64(m), nil
default:
return 0, errors.Errorf("invalid unit %s", unit)
}
}

View File

@ -30,7 +30,7 @@ import (
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/expression/builtin"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util/types"
@ -45,13 +45,15 @@ const (
ExprEvalPositionFunc = "$positionFunc"
// ExprEvalValuesFunc is the key saving a function to retrieve value for column name.
ExprEvalValuesFunc = "$valuesFunc"
// ExprEvalIdentReferFunc is the key saving a function to retrieve value with identifier reference index.
ExprEvalIdentReferFunc = "$identReferFunc"
)
var (
// CurrentTimestamp is the keyword getting default value for datetime and timestamp type.
CurrentTimestamp = "CURRENT_TIMESTAMP"
// CurrentTimeExpr is the expression retireving default value for datetime and timestamp type.
CurrentTimeExpr = &Ident{model.NewCIStr(CurrentTimestamp)}
CurrentTimeExpr = &Ident{CIStr: model.NewCIStr(CurrentTimestamp)}
// ZeroTimestamp shows the zero datetime and timestamp.
ZeroTimestamp = "0000-00-00 00:00:00"
)
@ -142,22 +144,40 @@ func newMentionedAggregateFuncsVisitor() *mentionedAggregateFuncsVisitor {
}
func (v *mentionedAggregateFuncsVisitor) VisitCall(c *Call) (Expression, error) {
isAggregate := IsAggregateFunc(c.F)
if isAggregate {
v.exprs = append(v.exprs, c)
}
n := len(v.exprs)
for _, e := range c.Args {
_, err := e.Accept(v)
if err != nil {
return nil, errors.Trace(err)
}
}
f, ok := builtin.Funcs[strings.ToLower(c.F)]
if !ok {
return nil, errors.Errorf("unknown function %s", c.F)
}
if f.IsAggregate {
v.exprs = append(v.exprs, c)
if isAggregate && len(v.exprs) != n {
// aggregate function can't use aggregate function as the arg.
// here means we have aggregate function in arg.
return nil, errors.Errorf("Invalid use of group function")
}
return c, nil
}
// IsAggregateFunc checks whether name is an aggregate function or not.
func IsAggregateFunc(name string) bool {
// TODO: use switch defined aggregate name "sum", "count", etc... directly.
// Maybe we can remove builtin IsAggregate field later.
f, ok := builtin.Funcs[strings.ToLower(name)]
if !ok {
return false
}
return f.IsAggregate
}
// MentionedColumns returns a list of names for Ident expression.
func MentionedColumns(e Expression) []string {
var names []string

View File

@ -5,9 +5,8 @@ import (
"time"
. "github.com/pingcap/check"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/sessionctx/variable"
)
@ -47,6 +46,15 @@ func (s *testHelperSuite) TestContainAggFunc(c *C) {
b := ContainAggregateFunc(t.Expr)
c.Assert(b, Equals, t.Expect)
}
expr := &Call{
F: "count",
Args: []Expression{
&Call{F: "count", Args: []Expression{v}},
},
}
_, err := MentionedAggregateFuncs(expr)
c.Assert(err, NotNil)
}
func (s *testHelperSuite) TestMentionedColumns(c *C) {
@ -57,7 +65,7 @@ func (s *testHelperSuite) TestMentionedColumns(c *C) {
}{
{Value{1}, 0},
{&BinaryOperation{L: v, R: v}, 0},
{&Ident{model.NewCIStr("id")}, 1},
{&Ident{CIStr: model.NewCIStr("id")}, 1},
{&Call{F: "count", Args: []Expression{v}}, 0},
{&IsNull{Expr: v}, 0},
{&PExpr{Expr: v}, 0},
@ -106,7 +114,7 @@ func (s *testHelperSuite) TestBase(c *C) {
{int64(1), int64(1)},
{&UnaryOperation{Op: opcode.Plus, V: Value{1}}, 1},
{&UnaryOperation{Op: opcode.Not, V: Value{1}}, nil},
{&UnaryOperation{Op: opcode.Plus, V: &Ident{model.NewCIStr("id")}}, nil},
{&UnaryOperation{Op: opcode.Plus, V: &Ident{CIStr: model.NewCIStr("id")}}, nil},
{nil, nil},
}
@ -228,7 +236,7 @@ func (s *testHelperSuite) TestGetTimeValue(c *C) {
{Value{"2012-13-12 00:00:00"}},
{Value{0}},
{Value{int64(1)}},
{&Ident{model.NewCIStr("xxx")}},
{&Ident{CIStr: model.NewCIStr("xxx")}},
{NewUnaryOperation(opcode.Minus, Value{int64(1)})},
}

View File

@ -29,10 +29,23 @@ var (
_ Expression = (*Ident)(nil)
)
const (
// IdentReferSelectList means the identifier reference is in select list.
IdentReferSelectList = 1
// IdentReferFromTable means the identifier reference is in FROM table.
IdentReferFromTable = 2
)
// Ident is the identifier expression.
type Ident struct {
// model.CIStr contains origin identifier name and its lowercase name.
model.CIStr
// ReferScope means where the identifer reference is, select list or from.
ReferScope int
// ReferIndex is the index to get the identifer data.
ReferIndex int
}
// Clone implements the Expression Clone interface.
@ -63,6 +76,14 @@ func (i *Ident) Eval(ctx context.Context, args map[interface{}]interface{}) (v i
return nil, nil
}
// TODO: we will unify ExprEvalIdentReferFunc and ExprEvalIdentFunc later,
// now just put them here for refactor step by step.
if f, ok := args[ExprEvalIdentReferFunc]; ok {
if got, ok := f.(func(string, int, int) (interface{}, error)); ok {
return got(i.L, i.ReferScope, i.ReferIndex)
}
}
if f, ok := args[ExprEvalIdentFunc]; ok {
if got, ok := f.(func(string) (interface{}, error)); ok {
return got(i.L)

View File

@ -26,7 +26,7 @@ type testIdentSuite struct {
func (s *testIdentSuite) TestIdent(c *C) {
e := Ident{
model.NewCIStr("id"),
CIStr: model.NewCIStr("id"),
}
c.Assert(e.IsStatic(), IsFalse)
@ -55,4 +55,14 @@ func (s *testIdentSuite) TestIdent(c *C) {
v, err = e.Eval(nil, m)
c.Assert(err, IsNil)
c.Assert(v, Equals, 1)
delete(m, ExprEvalIdentFunc)
e.ReferScope = IdentReferSelectList
e.ReferIndex = 1
m[ExprEvalIdentReferFunc] = func(string, int, int) (interface{}, error) {
return 2, nil
}
v, err = e.Eval(nil, m)
c.Assert(err, IsNil)
c.Assert(v, Equals, 2)
}

View File

@ -23,7 +23,7 @@ import (
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
)

View File

@ -19,7 +19,7 @@ import (
. "github.com/pingcap/check"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
)

View File

@ -106,6 +106,9 @@ type Visitor interface {
// VisitFunctionTrim visits FunctionTrim expression.
VisitFunctionTrim(v *FunctionTrim) (Expression, error)
// VisitDateAdd visits DateAdd expression.
VisitDateAdd(da *DateAdd) (Expression, error)
}
// BaseVisitor is the base implementation of Visitor.
@ -456,7 +459,7 @@ func (bv *BaseVisitor) VisitWhenClause(w *WhenClause) (Expression, error) {
return w, nil
}
// VisitExtract implements Visitor
// VisitExtract implements Visitor interface.
func (bv *BaseVisitor) VisitExtract(v *Extract) (Expression, error) {
var err error
v.Date, err = v.Date.Accept(bv.V)
@ -478,3 +481,19 @@ func (bv *BaseVisitor) VisitFunctionTrim(ss *FunctionTrim) (Expression, error) {
}
return ss, nil
}
// VisitDateAdd implements Visitor interface.
func (bv *BaseVisitor) VisitDateAdd(da *DateAdd) (Expression, error) {
var err error
da.Date, err = da.Date.Accept(bv.V)
if err != nil {
return da, errors.Trace(err)
}
da.Interval, err = da.Interval.Accept(bv.V)
if err != nil {
return da, errors.Trace(err)
}
return da, nil
}

View File

@ -17,7 +17,7 @@ import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
)

View File

@ -19,7 +19,7 @@ import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/field"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/util/types"
)

View File

@ -23,7 +23,7 @@ import (
"github.com/juju/errors"
"github.com/pingcap/tidb/column"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
)
const (

View File

@ -19,7 +19,7 @@ import (
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/field"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/util/types"
)

View File

@ -19,7 +19,7 @@ import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/store/localstore"
"github.com/pingcap/tidb/store/localstore/goleveldb"
"github.com/pingcap/tidb/util/types"
@ -51,7 +51,7 @@ func (*testSuite) TestT(c *C) {
ID: 3,
Name: colName,
Offset: 0,
FieldType: *types.NewFieldType(mysqldef.TypeLonglong),
FieldType: *types.NewFieldType(mysql.TypeLonglong),
}
idxInfo := &model.IndexInfo{

40
kv/compactor.go Normal file
View File

@ -0,0 +1,40 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package kv
import "time"
// CompactPolicy defines gc policy of MVCC storage.
type CompactPolicy struct {
// SafePoint specifies
SafePoint int
// TriggerInterval specifies how often should the compactor
// scans outdated data.
TriggerInterval time.Duration
// BatchDeleteCnt specifies the batch size for
// deleting outdated data transaction.
BatchDeleteCnt int
}
// Compactor compacts MVCC storage.
type Compactor interface {
// OnGet is the hook point on Txn.Get.
OnGet(k Key)
// OnSet is the hook point on Txn.Set.
OnSet(k Key)
// OnDelete is the hook point on Txn.Delete.
OnDelete(k Key)
// Compact is the function removes the given key.
Compact(ctx interface{}, k Key) error
}

View File

@ -87,7 +87,7 @@ func (c *indexIter) Next() (k []interface{}, h int64, err error) {
k = vv
}
// update new iter to next
newIt, err := c.it.Next(hasPrefix([]byte(c.prefix)))
newIt, err := c.it.Next()
if err != nil {
return nil, 0, errors.Trace(err)
}
@ -189,16 +189,10 @@ func (c *kvIndex) Delete(txn Transaction, indexedValues []interface{}, h int64)
return errors.Trace(err)
}
func hasPrefix(prefix []byte) FnKeyCmp {
return func(k Key) bool {
return bytes.HasPrefix([]byte(k), prefix)
}
}
// Drop removes the KV index from store.
func (c *kvIndex) Drop(txn Transaction) error {
prefix := []byte(c.prefix)
it, err := txn.Seek(Key(prefix), hasPrefix(prefix))
it, err := txn.Seek(Key(prefix))
if err != nil {
return errors.Trace(err)
}
@ -213,7 +207,7 @@ func (c *kvIndex) Drop(txn Transaction) error {
if err != nil {
return errors.Trace(err)
}
it, err = it.Next(hasPrefix(prefix))
it, err = it.Next()
if err != nil {
return errors.Trace(err)
}
@ -227,7 +221,7 @@ func (c *kvIndex) Seek(txn Transaction, indexedValues []interface{}) (iter Index
if err != nil {
return nil, false, errors.Trace(err)
}
it, err := txn.Seek(keyBuf, hasPrefix([]byte(c.prefix)))
it, err := txn.Seek(keyBuf)
if err != nil {
return nil, false, errors.Trace(err)
}
@ -242,7 +236,7 @@ func (c *kvIndex) Seek(txn Transaction, indexedValues []interface{}) (iter Index
// SeekFirst returns an iterator which points to the first entry of the KV index.
func (c *kvIndex) SeekFirst(txn Transaction) (iter IndexIterator, err error) {
prefix := []byte(c.prefix)
it, err := txn.Seek(prefix, hasPrefix(prefix))
it, err := txn.Seek(prefix)
if err != nil {
return nil, errors.Trace(err)
}

View File

@ -72,7 +72,7 @@ func DecodeValue(data []byte) ([]interface{}, error) {
func NextUntil(it Iterator, fn FnKeyCmp) (Iterator, error) {
var err error
for it.Valid() && !fn([]byte(it.Key())) {
it, err = it.Next(nil)
it, err = it.Next()
if err != nil {
return nil, errors.Trace(err)
}

View File

@ -104,7 +104,7 @@ type Transaction interface {
// Set sets the value for key k as v into KV store.
Set(k Key, v []byte) error
// Seek searches for the entry with key k in KV store.
Seek(k Key, fnKeyCmp func(key Key) bool) (Iterator, error)
Seek(k Key) (Iterator, error)
// Inc increases the value for key k in KV store by step.
Inc(k Key, step int64) (int64, error)
// GetInt64 get int64 which created by Inc method.
@ -171,7 +171,7 @@ type FnKeyCmp func(key Key) bool
// Iterator is the interface for a interator on KV store.
type Iterator interface {
Next(FnKeyCmp) (Iterator, error)
Next() (Iterator, error)
Value() []byte
Key() string
Valid() bool

View File

@ -51,7 +51,7 @@ func (iter *UnionIter) dirtyNext() {
// Go next and update valid status.
func (iter *UnionIter) snapshotNext() {
iter.snapshotIt, _ = iter.snapshotIt.Next(nil)
iter.snapshotIt, _ = iter.snapshotIt.Next()
iter.snapshotValid = iter.snapshotIt.Valid()
}
@ -116,7 +116,7 @@ func (iter *UnionIter) updateCur() {
}
// Next implements the Iterator Next interface.
func (iter *UnionIter) Next(f FnKeyCmp) (Iterator, error) {
func (iter *UnionIter) Next() (Iterator, error) {
if !iter.curIsDirty {
iter.snapshotNext()
} else {

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
import (
"fmt"

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
import . "github.com/pingcap/check"

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
// CharsetIDs maps charset name to its default collation ID.
var CharsetIDs = map[string]uint8{

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
// Version informations.
const (

View File

@ -57,7 +57,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
// Decimal implements an arbitrary precision fixed-point decimal.
//

View File

@ -57,7 +57,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
import (
"encoding/json"

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
import (
"strconv"

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
import (
. "github.com/pingcap/check"

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
// MySQL error code.
// This value is numeric. It is not portable to other database systems.

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
// MySQLErrName maps error code to MySQL error messages.
var MySQLErrName = map[uint16]string{

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
import (
"errors"

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
import (
. "github.com/pingcap/check"

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
import (
"math"
@ -80,3 +80,13 @@ func parseFrac(s string, fsp int) (int, error) {
// 0.0312 round 2 -> 3 -> 30000
return int(round * math.Pow10(MaxFsp-fsp)), nil
}
// alignFrac is used to generate alignment frac, like `100` -> `100000`
func alignFrac(s string, fsp int) string {
sl := len(s)
if sl < fsp {
return s + strings.Repeat("0", fsp-sl)
}
return s
}

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
import (
"encoding/hex"

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
import (
"strconv"

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
import (
"strconv"

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
import (
. "github.com/pingcap/check"

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
const (
// DefaultMySQLState is default state of the mySQL

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
import (
"bytes"
@ -979,3 +979,364 @@ func checkTimestamp(t Time) bool {
return true
}
// ExtractTimeNum extracts time value number from time unit and format.
func ExtractTimeNum(unit string, t Time) (int64, error) {
switch strings.ToUpper(unit) {
case "MICROSECOND":
return int64(t.Nanosecond() / 1000), nil
case "SECOND":
return int64(t.Second()), nil
case "MINUTE":
return int64(t.Minute()), nil
case "HOUR":
return int64(t.Hour()), nil
case "DAY":
return int64(t.Day()), nil
case "WEEK":
_, week := t.ISOWeek()
return int64(week), nil
case "MONTH":
return int64(t.Month()), nil
case "QUARTER":
m := int64(t.Month())
// 1 - 3 -> 1
// 4 - 6 -> 2
// 7 - 9 -> 3
// 10 - 12 -> 4
return (m + 2) / 3, nil
case "YEAR":
return int64(t.Year()), nil
case "SECOND_MICROSECOND":
return int64(t.Second())*1000000 + int64(t.Nanosecond())/1000, nil
case "MINUTE_MICROSECOND":
_, m, s := t.Clock()
return int64(m)*100000000 + int64(s)*1000000 + int64(t.Nanosecond())/1000, nil
case "MINUTE_SECOND":
_, m, s := t.Clock()
return int64(m*100 + s), nil
case "HOUR_MICROSECOND":
h, m, s := t.Clock()
return int64(h)*10000000000 + int64(m)*100000000 + int64(s)*1000000 + int64(t.Nanosecond())/1000, nil
case "HOUR_SECOND":
h, m, s := t.Clock()
return int64(h)*10000 + int64(m)*100 + int64(s), nil
case "HOUR_MINUTE":
h, m, _ := t.Clock()
return int64(h)*100 + int64(m), nil
case "DAY_MICROSECOND":
h, m, s := t.Clock()
d := t.Day()
return int64(d*1000000+h*10000+m*100+s)*1000000 + int64(t.Nanosecond())/1000, nil
case "DAY_SECOND":
h, m, s := t.Clock()
d := t.Day()
return int64(d)*1000000 + int64(h)*10000 + int64(m)*100 + int64(s), nil
case "DAY_MINUTE":
h, m, _ := t.Clock()
d := t.Day()
return int64(d)*10000 + int64(h)*100 + int64(m), nil
case "DAY_HOUR":
h, _, _ := t.Clock()
d := t.Day()
return int64(d)*100 + int64(h), nil
case "YEAR_MONTH":
y, m, _ := t.Date()
return int64(y)*100 + int64(m), nil
default:
return 0, errors.Errorf("invalid unit %s", unit)
}
}
func extractSingleTimeValue(unit string, format string) (int64, int64, int64, time.Duration, error) {
iv, err := strconv.ParseInt(format, 10, 64)
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
v := time.Duration(iv)
switch strings.ToUpper(unit) {
case "MICROSECOND":
return 0, 0, 0, v * time.Microsecond, nil
case "SECOND":
return 0, 0, 0, v * time.Second, nil
case "MINUTE":
return 0, 0, 0, v * time.Minute, nil
case "HOUR":
return 0, 0, 0, v * time.Hour, nil
case "DAY":
return 0, 0, iv, 0, nil
case "WEEK":
return 0, 0, 7 * iv, 0, nil
case "MONTH":
return 0, iv, 0, 0, nil
case "QUARTER":
return 0, 3 * iv, 0, 0, nil
case "YEAR":
return iv, 0, 0, 0, nil
}
return 0, 0, 0, 0, errors.Errorf("invalid singel timeunit - %s", unit)
}
// Format is `SS.FFFFFF`.
func extractSecondMicrosecond(format string) (int64, int64, int64, time.Duration, error) {
fields := strings.Split(format, ".")
if len(fields) != 2 {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
seconds, err := strconv.ParseInt(fields[0], 10, 64)
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
microseconds, err := strconv.ParseInt(alignFrac(fields[1], MaxFsp), 10, 64)
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
return 0, 0, 0, time.Duration(seconds)*time.Second + time.Duration(microseconds)*time.Microsecond, nil
}
// Format is `MM:SS.FFFFFF`.
func extractMinuteMicrosecond(format string) (int64, int64, int64, time.Duration, error) {
fields := strings.Split(format, ":")
if len(fields) != 2 {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
minutes, err := strconv.ParseInt(fields[0], 10, 64)
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
_, _, _, value, err := extractSecondMicrosecond(fields[1])
if err != nil {
return 0, 0, 0, 0, errors.Trace(err)
}
return 0, 0, 0, time.Duration(minutes)*time.Minute + value, nil
}
// Format is `MM:SS`.
func extractMinuteSecond(format string) (int64, int64, int64, time.Duration, error) {
fields := strings.Split(format, ":")
if len(fields) != 2 {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
minutes, err := strconv.ParseInt(fields[0], 10, 64)
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
seconds, err := strconv.ParseInt(fields[1], 10, 64)
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
return 0, 0, 0, time.Duration(minutes)*time.Minute + time.Duration(seconds)*time.Second, nil
}
// Format is `HH:MM:SS.FFFFFF`.
func extractHourMicrosecond(format string) (int64, int64, int64, time.Duration, error) {
fields := strings.Split(format, ":")
if len(fields) != 3 {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
hours, err := strconv.ParseInt(fields[0], 10, 64)
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
minutes, err := strconv.ParseInt(fields[1], 10, 64)
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
_, _, _, value, err := extractSecondMicrosecond(fields[2])
if err != nil {
return 0, 0, 0, 0, errors.Trace(err)
}
return 0, 0, 0, time.Duration(hours)*time.Hour + time.Duration(minutes)*time.Minute + value, nil
}
// Format is `HH:MM:SS`.
func extractHourSecond(format string) (int64, int64, int64, time.Duration, error) {
fields := strings.Split(format, ":")
if len(fields) != 3 {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
hours, err := strconv.ParseInt(fields[0], 10, 64)
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
minutes, err := strconv.ParseInt(fields[1], 10, 64)
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
seconds, err := strconv.ParseInt(fields[2], 10, 64)
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
return 0, 0, 0, time.Duration(hours)*time.Hour + time.Duration(minutes)*time.Minute + time.Duration(seconds)*time.Second, nil
}
// Format is `HH:MM`.
func extractHourMinute(format string) (int64, int64, int64, time.Duration, error) {
fields := strings.Split(format, ":")
if len(fields) != 2 {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
hours, err := strconv.ParseInt(fields[0], 10, 64)
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
minutes, err := strconv.ParseInt(fields[1], 10, 64)
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
return 0, 0, 0, time.Duration(hours)*time.Hour + time.Duration(minutes)*time.Minute, nil
}
// Format is `DD HH:MM:SS.FFFFFF`.
func extractDayMicrosecond(format string) (int64, int64, int64, time.Duration, error) {
fields := strings.Split(format, " ")
if len(fields) != 2 {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
days, err := strconv.ParseInt(fields[0], 10, 64)
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
_, _, _, value, err := extractHourMicrosecond(fields[1])
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
return 0, 0, days, value, nil
}
// Format is `DD HH:MM:SS`.
func extractDaySecond(format string) (int64, int64, int64, time.Duration, error) {
fields := strings.Split(format, " ")
if len(fields) != 2 {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
days, err := strconv.ParseInt(fields[0], 10, 64)
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
_, _, _, value, err := extractHourSecond(fields[1])
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
return 0, 0, days, value, nil
}
// Format is `DD HH:MM`.
func extractDayMinute(format string) (int64, int64, int64, time.Duration, error) {
fields := strings.Split(format, " ")
if len(fields) != 2 {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
days, err := strconv.ParseInt(fields[0], 10, 64)
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
_, _, _, value, err := extractHourMinute(fields[1])
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
return 0, 0, days, value, nil
}
// Format is `DD HH`.
func extractDayHour(format string) (int64, int64, int64, time.Duration, error) {
fields := strings.Split(format, " ")
if len(fields) != 2 {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
days, err := strconv.ParseInt(fields[0], 10, 64)
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
hours, err := strconv.ParseInt(fields[1], 10, 64)
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
return 0, 0, days, time.Duration(hours) * time.Hour, nil
}
// Format is `YYYY-MM`.
func extractYearMonth(format string) (int64, int64, int64, time.Duration, error) {
fields := strings.Split(format, "-")
if len(fields) != 2 {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
years, err := strconv.ParseInt(fields[0], 10, 64)
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
months, err := strconv.ParseInt(fields[1], 10, 64)
if err != nil {
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
}
return years, months, 0, 0, nil
}
// ExtractTimeValue extracts time value from time unit and format.
func ExtractTimeValue(unit string, format string) (int64, int64, int64, time.Duration, error) {
switch strings.ToUpper(unit) {
case "MICROSECOND", "SECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "QUARTER", "YEAR":
return extractSingleTimeValue(unit, format)
case "SECOND_MICROSECOND":
return extractSecondMicrosecond(format)
case "MINUTE_MICROSECOND":
return extractMinuteMicrosecond(format)
case "MINUTE_SECOND":
return extractMinuteSecond(format)
case "HOUR_MICROSECOND":
return extractHourMicrosecond(format)
case "HOUR_SECOND":
return extractHourSecond(format)
case "HOUR_MINUTE":
return extractHourMinute(format)
case "DAY_MICROSECOND":
return extractDayMicrosecond(format)
case "DAY_SECOND":
return extractDaySecond(format)
case "DAY_MINUTE":
return extractDayMinute(format)
case "DAY_HOUR":
return extractDayHour(format)
case "YEAR_MONTH":
return extractYearMonth(format)
default:
return 0, 0, 0, 0, errors.Errorf("invalid singel timeunit - %s", unit)
}
}

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
import (
"testing"

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
// MySQL type informations.
const (

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
import . "github.com/pingcap/check"

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
// GetDefaultFieldLength is used for Interger Types, Flen is the display length.
// Call this when no Flen assigned in ddl.

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package mysqldef
package mysql
import "testing"

View File

@ -21,7 +21,7 @@ import (
"github.com/pingcap/tidb/column"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/util/charset"
"github.com/pingcap/tidb/util/types"

View File

@ -18,7 +18,7 @@ import (
"strings"
"github.com/pingcap/tidb/expression"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
)
// FloatOpt is used for parsing floating-point type option from SQL.

View File

@ -29,7 +29,7 @@ import (
"fmt"
"strings"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/parser/coldef"
"github.com/pingcap/tidb/ddl"
@ -113,6 +113,7 @@ import (
currentUser "CURRENT_USER"
database "DATABASE"
databases "DATABASES"
dateAdd "DATE_ADD"
day "DAY"
dayofmonth "DAYOFMONTH"
dayofweek "DAYOFWEEK"
@ -164,6 +165,7 @@ import (
index "INDEX"
inner "INNER"
insert "INSERT"
interval "INTERVAL"
into "INTO"
is "IS"
join "JOIN"
@ -707,7 +709,6 @@ ColumnPosition:
}
}
AlterSpecificationList:
AlterSpecification
{
@ -1702,7 +1703,7 @@ UnReservedKeyword:
| "NATIONAL" | "ROW" | "QUARTER" | "ESCAPE"
NotKeywordToken:
"ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT"
"ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DATE_ADD" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT"
| "HOUR" | "IFNULL" | "LENGTH" | "LOCATE" | "MAX" | "MICROSECOND" | "MIN" | "MINUTE" | "NULLIF" | "MONTH" | "NOW" | "RAND" | "SECOND" | "SQL_CALC_FOUND_ROWS"
| "SUBSTRING" %prec lowerThanLeftParen | "SUBSTRING_INDEX" | "SUM" | "TRIM" | "WEEKDAY" | "WEEKOFYEAR" | "YEARWEEK"
@ -2239,6 +2240,14 @@ FunctionCallNonKeyword:
return 1
}
}
| "DATE_ADD" '(' Expression ',' "INTERVAL" Expression TimeUnit ')'
{
$$ = &expression.DateAdd{
Unit: $7.(string),
Date: $3.(expression.Expression),
Interval: $6.(expression.Expression),
}
}
| "EXTRACT" '(' TimeUnit "FROM" Expression ')'
{
$$ = &expression.Extract{

View File

@ -403,6 +403,28 @@ func (s *testParserSuite) TestBuiltin(c *C) {
{`SELECT TRIM(LEADING 'x' FROM 'xxxbarxxx');`, true},
{`SELECT TRIM(BOTH 'x' FROM 'xxxbarxxx');`, true},
{`SELECT TRIM(TRAILING 'xyz' FROM 'barxxyz');`, true},
// For date_add
{`select date_add("2011-11-11 10:10:10.123456", interval 10 microsecond)`, true},
{`select date_add("2011-11-11 10:10:10.123456", interval 10 second)`, true},
{`select date_add("2011-11-11 10:10:10.123456", interval 10 minute)`, true},
{`select date_add("2011-11-11 10:10:10.123456", interval 10 hour)`, true},
{`select date_add("2011-11-11 10:10:10.123456", interval 10 day)`, true},
{`select date_add("2011-11-11 10:10:10.123456", interval 1 week)`, true},
{`select date_add("2011-11-11 10:10:10.123456", interval 1 month)`, true},
{`select date_add("2011-11-11 10:10:10.123456", interval 1 quarter)`, true},
{`select date_add("2011-11-11 10:10:10.123456", interval 1 year)`, true},
{`select date_add("2011-11-11 10:10:10.123456", interval "10.10" second_microsecond)`, true},
{`select date_add("2011-11-11 10:10:10.123456", interval "10:10.10" minute_microsecond)`, true},
{`select date_add("2011-11-11 10:10:10.123456", interval "10:10" minute_second)`, true},
{`select date_add("2011-11-11 10:10:10.123456", interval "10:10:10.10" hour_microsecond)`, true},
{`select date_add("2011-11-11 10:10:10.123456", interval "10:10:10" hour_second)`, true},
{`select date_add("2011-11-11 10:10:10.123456", interval "10:10" hour_minute)`, true},
{`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10:10.10" day_microsecond)`, true},
{`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10:10" day_second)`, true},
{`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true},
{`select date_add("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true},
{`select date_add("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true},
}
s.RunTest(c, table)
}

View File

@ -27,7 +27,7 @@ import (
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/util/stringutil"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
)
type lexer struct {
@ -284,6 +284,7 @@ current_date {c}{u}{r}{r}{e}{n}{t}_{d}{a}{t}{e}
current_user {c}{u}{r}{r}{e}{n}{t}_{u}{s}{e}{r}
database {d}{a}{t}{a}{b}{a}{s}{e}
databases {d}{a}{t}{a}{b}{a}{s}{e}{s}
date_add {d}{a}{t}{e}_{a}{d}{d}
day {d}{a}{y}
dayofweek {d}{a}{y}{o}{f}{w}{e}{e}{k}
dayofmonth {d}{a}{y}{o}{f}{m}{o}{n}{t}{h}
@ -331,6 +332,7 @@ in {i}{n}
index {i}{n}{d}{e}{x}
inner {i}{n}{n}{e}{r}
insert {i}{n}{s}{e}{r}{t}
interval {i}{n}{t}{e}{r}{v}{a}{l}
into {i}{n}{t}{o}
is {i}{s}
join {j}{o}{i}{n}
@ -630,6 +632,8 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h}
{database} lval.item = string(l.val)
return database
{databases} return databases
{date_add} lval.item = string(l.val)
return dateAdd
{day} lval.item = string(l.val)
return day
{dayofweek} lval.item = string(l.val)
@ -711,6 +715,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h}
{index} return index
{inner} return inner
{insert} return insert
{interval} return interval
{into} return into
{in} return in
{is} return is

View File

@ -64,10 +64,9 @@ func (r *SelectFieldsDefaultPlan) Next(ctx context.Context) (row *plan.Row, err
return nil, errors.Trace(err)
}
r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
v, err0 := GetIdentValue(name, r.Src.GetFields(), srcRow.Data, field.DefaultFieldFlag)
if err0 == nil {
return v, nil
r.evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
if scope == expression.IdentReferFromTable {
return srcRow.FromData[index], nil
}
return getIdentValueFromOuterQuery(ctx, name)

View File

@ -18,7 +18,7 @@ import (
"github.com/pingcap/tidb/column"
"github.com/pingcap/tidb/field"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/plan/plans"
"github.com/pingcap/tidb/rset/rsets"
"github.com/pingcap/tidb/util/charset"

View File

@ -68,7 +68,7 @@ func (r *TableNilPlan) Next(ctx context.Context) (row *plan.Row, err error) {
if err != nil {
return nil, errors.Trace(err)
}
r.iter, err = txn.Seek([]byte(r.T.FirstKey()), nil)
r.iter, err = txn.Seek([]byte(r.T.FirstKey()))
if err != nil {
return nil, errors.Trace(err)
}
@ -274,7 +274,7 @@ func (r *TableDefaultPlan) Next(ctx context.Context) (row *plan.Row, err error)
if err != nil {
return nil, errors.Trace(err)
}
r.iter, err = txn.Seek([]byte(r.T.FirstKey()), nil)
r.iter, err = txn.Seek([]byte(r.T.FirstKey()))
if err != nil {
return nil, errors.Trace(err)
}

View File

@ -24,7 +24,7 @@ import (
"github.com/pingcap/tidb/field"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/plan/plans"
"github.com/pingcap/tidb/rset/rsets"

View File

@ -89,15 +89,11 @@ type groupRow struct {
func (r *GroupByDefaultPlan) evalGroupKey(ctx context.Context, k []interface{}, outRow []interface{}, in []interface{}) error {
// group by items can not contain aggregate field, so we can eval them safely.
m := map[interface{}]interface{}{}
m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag)
if err == nil {
return v, nil
}
v, err = r.getFieldValueByName(name, outRow)
if err == nil {
return v, nil
m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
if scope == expression.IdentReferFromTable {
return in[index], nil
} else if scope == expression.IdentReferSelectList {
return outRow[index], nil
}
// try to find in outer query
@ -132,10 +128,11 @@ func (r *GroupByDefaultPlan) getFieldValueByName(name string, out []interface{})
func (r *GroupByDefaultPlan) evalNoneAggFields(ctx context.Context, out []interface{},
m map[interface{}]interface{}, in []interface{}) error {
m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag)
if err == nil {
return v, nil
m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
if scope == expression.IdentReferFromTable {
return in[index], nil
} else if scope == expression.IdentReferSelectList {
return out[index], nil
}
// try to find in outer query
@ -157,42 +154,20 @@ func (r *GroupByDefaultPlan) evalNoneAggFields(ctx context.Context, out []interf
func (r *GroupByDefaultPlan) evalAggFields(ctx context.Context, out []interface{},
m map[interface{}]interface{}, in []interface{}) error {
// Eval aggregate field results in ctx
for i := range r.AggFields {
if i < r.HiddenFieldOffset {
m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag)
if err == nil {
return v, nil
}
// try to find in outer query
return getIdentValueFromOuterQuery(ctx, name)
}
} else {
// having may contain aggregate function and we will add it to hidden field,
// and this field can retrieve the data in select list, e.g.
// select c1 as a from t having count(a) = 1
// because all the select list data is generated before, so we can get it
// when handling hidden field.
m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag)
if err == nil {
return v, nil
}
// if we can not find in table, we will try to find in un-hidden select list
// only hidden field can use this
v, err = r.getFieldValueByName(name, out)
if err == nil {
return v, nil
}
// try to find in outer query
return getIdentValueFromOuterQuery(ctx, name)
}
m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
if scope == expression.IdentReferFromTable {
return in[index], nil
} else if scope == expression.IdentReferSelectList {
return out[index], nil
}
// try to find in outer query
return getIdentValueFromOuterQuery(ctx, name)
}
// Eval aggregate field results in ctx
for i := range r.AggFields {
// we must evaluate aggregate function only, e.g, select col1 + count(*) in (count(*)),
// we cannot evaluate it directly here, because col1 + count(*) returns nil before AggDone phase,
// so we don't evaluate count(*) in In expression, and will get an invalid data in AggDone phase for it.

View File

@ -42,12 +42,16 @@ func (t *testGroupBySuite) TestGroupBy(c *C) {
Fields: []*field.Field{
{
Expr: &expression.Ident{
CIStr: model.NewCIStr("id"),
CIStr: model.NewCIStr("id"),
ReferScope: expression.IdentReferFromTable,
ReferIndex: 0,
},
},
{
Expr: &expression.Ident{
CIStr: model.NewCIStr("name"),
CIStr: model.NewCIStr("name"),
ReferScope: expression.IdentReferFromTable,
ReferIndex: 1,
},
},
{
@ -55,7 +59,9 @@ func (t *testGroupBySuite) TestGroupBy(c *C) {
F: "sum",
Args: []expression.Expression{
&expression.Ident{
CIStr: model.NewCIStr("id"),
CIStr: model.NewCIStr("id"),
ReferScope: expression.IdentReferFromTable,
ReferIndex: 0,
},
},
},
@ -69,7 +75,9 @@ func (t *testGroupBySuite) TestGroupBy(c *C) {
Src: tblPlan,
By: []expression.Expression{
&expression.Ident{
CIStr: model.NewCIStr("id"),
CIStr: model.NewCIStr("id"),
ReferScope: expression.IdentReferFromTable,
ReferIndex: 0,
},
},
}

View File

@ -62,10 +62,11 @@ func (r *HavingPlan) Next(ctx context.Context) (row *plan.Row, err error) {
if srcRow == nil || err != nil {
return nil, errors.Trace(err)
}
r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
v, err0 := GetIdentValue(name, r.Src.GetFields(), srcRow.Data, field.CheckFieldFlag)
if err0 == nil {
return v, nil
r.evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
if scope == expression.IdentReferFromTable {
return srcRow.FromData[index], nil
} else if scope == expression.IdentReferSelectList {
return srcRow.Data[index], nil
}
// try to find in outer query

View File

@ -38,7 +38,9 @@ func (t *testHavingPlan) TestHaving(c *C) {
Expr: &expression.BinaryOperation{
Op: opcode.GE,
L: &expression.Ident{
CIStr: model.NewCIStr("id"),
CIStr: model.NewCIStr("id"),
ReferScope: expression.IdentReferSelectList,
ReferIndex: 0,
},
R: &expression.Value{
Val: 20,

View File

@ -24,7 +24,7 @@ import (
"github.com/pingcap/tidb/field"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/plan/plans"
"github.com/pingcap/tidb/rset/rsets"

View File

@ -25,7 +25,7 @@ import (
"github.com/pingcap/tidb/field"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/plan"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/charset"

View File

@ -155,6 +155,7 @@ func (r *JoinPlan) Next(ctx context.Context) (row *plan.Row, err error) {
}
func (r *JoinPlan) nextLeftJoin(ctx context.Context) (row *plan.Row, err error) {
visitor := newJoinIdentVisitor(0)
for {
if r.cursor < len(r.matchedRows) {
row = r.matchedRows[r.cursor]
@ -167,7 +168,7 @@ func (r *JoinPlan) nextLeftJoin(ctx context.Context) (row *plan.Row, err error)
return nil, errors.Trace(err)
}
tempExpr := r.On.Clone()
visitor := NewIdentEvalVisitor(r.Left.GetFields(), leftRow.Data)
visitor.row = leftRow.Data
_, err = tempExpr.Accept(visitor)
if err != nil {
return nil, errors.Trace(err)
@ -189,6 +190,7 @@ func (r *JoinPlan) nextLeftJoin(ctx context.Context) (row *plan.Row, err error)
}
func (r *JoinPlan) nextRightJoin(ctx context.Context) (row *plan.Row, err error) {
visitor := newJoinIdentVisitor(len(r.Left.GetFields()))
for {
if r.cursor < len(r.matchedRows) {
row = r.matchedRows[r.cursor]
@ -202,7 +204,7 @@ func (r *JoinPlan) nextRightJoin(ctx context.Context) (row *plan.Row, err error)
}
tempExpr := r.On.Clone()
visitor := NewIdentEvalVisitor(r.Right.GetFields(), rightRow.Data)
visitor.row = rightRow.Data
_, err = tempExpr.Accept(visitor)
if err != nil {
return nil, errors.Trace(err)
@ -251,8 +253,8 @@ func (r *JoinPlan) findMatchedRows(ctx context.Context, row *plan.Row, p plan.Pl
} else {
joined = append(append(joined, row.Data...), cmpRow.Data...)
}
r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
return GetIdentValue(name, r.Fields, joined, field.DefaultFieldFlag)
r.evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
return joined[index], nil
}
var b bool
b, err = expression.EvalBoolExpr(ctx, r.On, r.evalArgs)
@ -280,6 +282,7 @@ func (r *JoinPlan) findMatchedRows(ctx context.Context, row *plan.Row, p plan.Pl
}
func (r *JoinPlan) nextCrossJoin(ctx context.Context) (row *plan.Row, err error) {
visitor := newJoinIdentVisitor(0)
for {
if r.curRow == nil {
r.curRow, err = r.Left.Next(ctx)
@ -292,7 +295,7 @@ func (r *JoinPlan) nextCrossJoin(ctx context.Context) (row *plan.Row, err error)
if r.On != nil {
tempExpr := r.On.Clone()
visitor := NewIdentEvalVisitor(r.Left.GetFields(), r.curRow.Data)
visitor.row = r.curRow.Data
_, err = tempExpr.Accept(visitor)
if err != nil {
return nil, errors.Trace(err)
@ -323,8 +326,8 @@ func (r *JoinPlan) nextCrossJoin(ctx context.Context) (row *plan.Row, err error)
joinedRow = append(append(joinedRow, r.curRow.Data...), rightRow.Data...)
if r.On != nil {
r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
return GetIdentValue(name, r.Fields, joinedRow, field.DefaultFieldFlag)
r.evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
return joinedRow[index], nil
}
b, err := expression.EvalBoolExpr(ctx, r.On, r.evalArgs)
@ -356,32 +359,34 @@ func (r *JoinPlan) Close() error {
return r.Left.Close()
}
// IdentEvalVisitor converts Ident expression to value expression.
type IdentEvalVisitor struct {
// joinIdentVisitor converts Ident expression to value expression in ON expression.
type joinIdentVisitor struct {
expression.BaseVisitor
fields []*field.ResultField
row []interface{}
offset int
}
// NewIdentEvalVisitor creates a new IdentEvalVisitor.
func NewIdentEvalVisitor(fields []*field.ResultField, row []interface{}) *IdentEvalVisitor {
iev := &IdentEvalVisitor{fields: fields, row: row}
// newJoinIdentVisitor creates a new joinIdentVisitor.
func newJoinIdentVisitor(offset int) *joinIdentVisitor {
iev := &joinIdentVisitor{offset: offset}
iev.BaseVisitor.V = iev
return iev
}
// VisitIdent implements Visitor interface.
func (iev *IdentEvalVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) {
v, err := GetIdentValue(i.L, iev.fields, iev.row, field.CheckFieldFlag)
if err != nil {
func (iev *joinIdentVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) {
// the row here may be just left part or right part, but identifier may reference another part,
// so here we must check its reference index validation.
if i.ReferIndex < iev.offset || i.ReferIndex >= iev.offset+len(iev.row) {
return i, nil
}
v := iev.row[i.ReferIndex-iev.offset]
return expression.Value{Val: v}, nil
}
// VisitBinaryOperation swaps the right side identifier to left side if left side expression is static.
// So it can be used in index plan.
func (iev *IdentEvalVisitor) VisitBinaryOperation(binop *expression.BinaryOperation) (expression.Expression, error) {
func (iev *joinIdentVisitor) VisitBinaryOperation(binop *expression.BinaryOperation) (expression.Expression, error) {
var err error
binop.L, err = binop.L.Accept(iev)
if err != nil {

View File

@ -20,7 +20,7 @@ import (
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/field"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/plan/plans"
"github.com/pingcap/tidb/rset/rsets"

View File

@ -158,10 +158,11 @@ func (r *OrderByDefaultPlan) fetchAll(ctx context.Context) error {
if row == nil {
break
}
evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
v, err := GetIdentValue(name, r.ResultFields, row.Data, field.CheckFieldFlag)
if err == nil {
return v, nil
evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
if scope == expression.IdentReferFromTable {
return row.FromData[index], nil
} else if scope == expression.IdentReferSelectList {
return row.Data[index], nil
}
// try to find in outer query

View File

@ -48,7 +48,9 @@ func (t *testOrderBySuit) TestOrderBy(c *C) {
Src: tblPlan,
By: []expression.Expression{
&expression.Ident{
CIStr: model.NewCIStr("id"),
CIStr: model.NewCIStr("id"),
ReferScope: expression.IdentReferSelectList,
ReferIndex: 0,
},
},
Ascs: []bool{false},

View File

@ -22,7 +22,7 @@ import (
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/field"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/plan"
"github.com/pingcap/tidb/util/charset"
"github.com/pingcap/tidb/util/format"

View File

@ -199,5 +199,5 @@ func getIdentValueFromOuterQuery(ctx context.Context, name string) (interface{},
}
}
return nil, errors.Trace(err)
return nil, errors.Errorf("unknown field %s", name)
}

View File

@ -95,63 +95,43 @@ func (s *SelectList) GetFields() []*field.ResultField {
}
// UpdateAggFields adds aggregate function resultfield to select result field list.
func (s *SelectList) UpdateAggFields(expr expression.Expression, tableFields []*field.ResultField) (expression.Expression, error) {
// For aggregate function, the name can be in table or select list.
names := expression.MentionedColumns(expr)
for _, name := range names {
if field.ContainFieldName(name, tableFields, field.DefaultFieldFlag) {
continue
}
if field.ContainFieldName(name, s.ResultFields, field.DefaultFieldFlag) {
continue
}
return nil, errors.Errorf("Unknown column '%s'", name)
}
func (s *SelectList) UpdateAggFields(expr expression.Expression) (expression.Expression, error) {
// We must add aggregate function to hidden select list
// and use a position expression to fetch its value later.
exprName := expr.String()
idx := field.GetResultFieldIndex(exprName, s.ResultFields, field.CheckFieldFlag)
if len(idx) == 0 {
f := &field.Field{Expr: expr}
resultField := &field.ResultField{Name: exprName}
s.AddField(f, resultField)
name := strings.ToLower(expr.String())
index := -1
for i := 0; i < s.HiddenFieldOffset; i++ {
// only check origin name, e,g. "select sum(c1) as a from t order by sum(c1)"
// or "select sum(c1) from t order by sum(c1)"
if s.ResultFields[i].ColumnInfo.Name.L == name {
index = i
break
}
}
return &expression.Position{N: len(s.Fields), Name: exprName}, nil
if index == -1 {
f := &field.Field{Expr: expr}
s.AddField(f, nil)
pos := len(s.Fields)
s.AggFields[pos-1] = struct{}{}
return &expression.Position{N: pos, Name: name}, nil
}
// select list has this field, use it directly.
return &expression.Position{N: idx[0] + 1, Name: exprName}, nil
return &expression.Position{N: index + 1, Name: name}, nil
}
// CloneHiddenField checks and clones field and result field from table fields,
// and adds them to hidden field of select list.
func (s *SelectList) CloneHiddenField(name string, tableFields []*field.ResultField) bool {
// Check and add hidden field.
if field.ContainFieldName(name, tableFields, field.CheckFieldFlag) {
resultField, _ := field.CloneFieldByName(name, tableFields, field.CheckFieldFlag)
f := &field.Field{
Expr: &expression.Ident{
CIStr: resultField.ColumnInfo.Name,
},
}
s.AddField(f, resultField)
return true
}
return false
}
// CheckReferAmbiguous checks whether an identifier reference is ambiguous or not in select list.
// CheckAmbiguous checks whether an identifier reference is ambiguous or not in select list.
// e,g, "select c1 as a, c2 as a from t group by a" is ambiguous,
// but "select c1 as a, c1 as a from t group by a" is not.
// For MySQL "select c1 as a, c2 + 1 as a from t group by a" is not ambiguous too,
// so we will only check identifier too.
// If no ambiguous, -1 means expr refers none in select list, else an index in select list returns.
func (s *SelectList) CheckReferAmbiguous(expr expression.Expression) (int, error) {
// "select c1 as a, c2 + 1 as a from t group by a" is not ambiguous too,
// If no ambiguous, -1 means expr refers none in select list, else an index for first match.
// CheckAmbiguous will break the check when finding first matching which is not an indentifier,
// or an index for an identifier field in the end, -1 means none found.
func (s *SelectList) CheckAmbiguous(expr expression.Expression) (int, error) {
if _, ok := expr.(*expression.Ident); !ok {
return -1, nil
}
@ -162,14 +142,22 @@ func (s *SelectList) CheckReferAmbiguous(expr expression.Expression) (int, error
return -1, nil
}
// select c1 as a, 1 as a, c2 as a from t order by a is not ambiguous.
// select c1 as a, c2 as a from t order by a is ambiguous.
// select 1 as a, c1 as a from t order by a is not ambiguous.
// select c1 as a, sum(c1) as a from t group by a is error.
// select c1 as a, 1 as a, sum(c1) as a from t group by a is not error.
// so we will break the check if matching a none identifier field.
lastIndex := -1
// only check origin select list, no hidden field.
for i := 0; i < s.HiddenFieldOffset; i++ {
if !strings.EqualFold(s.ResultFields[i].Name, name) {
continue
} else if _, ok := s.Fields[i].Expr.(*expression.Ident); !ok {
// not identfier, no check
continue
}
if _, ok := s.Fields[i].Expr.(*expression.Ident); !ok {
// not identfier, return directly.
return i, nil
}
if lastIndex == -1 {
@ -179,6 +167,7 @@ func (s *SelectList) CheckReferAmbiguous(expr expression.Expression) (int, error
}
// check origin name, e,g. "select c1 as c2, c2 from t group by c2" is ambiguous.
if s.ResultFields[i].ColumnInfo.Name.L != s.ResultFields[lastIndex].ColumnInfo.Name.L {
return -1, errors.Errorf("refer %s is ambiguous", expr)
}
@ -233,6 +222,7 @@ func ResolveSelectList(selectFields []*field.Field, srcFields []*field.ResultFie
continue
}
// TODO: use fromIdentVisitor to cleanup.
var result *field.ResultField
for _, name := range names {
idx := field.GetResultFieldIndex(name, srcFields, field.DefaultFieldFlag)

View File

@ -96,7 +96,7 @@ func (s *testSelectListSuite) TestAmbiguous(c *C) {
sl, err := plans.ResolveSelectList(fs, rs)
c.Assert(err, IsNil)
index, err := sl.CheckReferAmbiguous(&expression.Ident{
idx, err := sl.CheckAmbiguous(&expression.Ident{
CIStr: model.NewCIStr(t.Name),
})
@ -106,6 +106,6 @@ func (s *testSelectListSuite) TestAmbiguous(c *C) {
}
c.Assert(err, IsNil)
c.Assert(t.Index, Equals, index)
c.Assert(t.Index, DeepEquals, idx)
}
}

View File

@ -25,7 +25,7 @@ import (
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/field"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/plan"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"

View File

@ -22,7 +22,7 @@ import (
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/plan/plans"
"github.com/pingcap/tidb/rset/rsets"

View File

@ -18,7 +18,7 @@ import (
"github.com/pingcap/tidb/column"
"github.com/pingcap/tidb/field"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/plan"
"github.com/pingcap/tidb/plan/plans"
"github.com/pingcap/tidb/rset/rsets"

View File

@ -56,15 +56,15 @@ func (r *FilterDefaultPlan) Next(ctx context.Context) (row *plan.Row, err error)
if row == nil || err != nil {
return nil, errors.Trace(err)
}
r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
v, err0 := GetIdentValue(name, r.GetFields(), row.Data, field.DefaultFieldFlag)
if err0 == nil {
return v, nil
r.evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
if scope == expression.IdentReferFromTable {
return row.Data[index], nil
}
// try to find in outer query
return getIdentValueFromOuterQuery(ctx, name)
}
var meet bool
meet, err = r.meetCondition(ctx)
if err != nil {

View File

@ -44,7 +44,9 @@ func (t *testWhereSuit) TestWhere(c *C) {
Expr: &expression.BinaryOperation{
Op: opcode.GE,
L: &expression.Ident{
CIStr: model.NewCIStr("id"),
CIStr: model.NewCIStr("id"),
ReferScope: expression.IdentReferFromTable,
ReferIndex: 0,
},
R: expression.Value{
Val: 30,

View File

@ -39,10 +39,29 @@ type SelectFieldsRset struct {
SelectList *plans.SelectList
}
func updateSelectFieldsRefer(selectList *plans.SelectList) error {
visitor := newFromIdentVisitor(selectList.FromFields, fieldListClause)
// we only fix un-hidden fields here, for hidden fields, it should be
// handled in their own clause, in order by or having.
for i, f := range selectList.Fields[0:selectList.HiddenFieldOffset] {
e, err := f.Expr.Accept(visitor)
if err != nil {
return errors.Trace(err)
}
selectList.Fields[i].Expr = e
}
return nil
}
// Plan gets SrcPlan/SelectFieldsDefaultPlan.
// If all fields are equal to src plan fields, then gets SrcPlan.
// Default gets SelectFieldsDefaultPlan.
func (r *SelectFieldsRset) Plan(ctx context.Context) (plan.Plan, error) {
if err := updateSelectFieldsRefer(r.SelectList); err != nil {
return nil, errors.Trace(err)
}
fields := r.SelectList.Fields
srcFields := r.Src.GetFields()
if len(fields) == len(srcFields) {
@ -97,6 +116,7 @@ func (r *SelectFromDualRset) Plan(ctx context.Context) (plan.Plan, error) {
// field cannot contain identifier
for _, f := range r.Fields {
if cs := expression.MentionedColumns(f.Expr); len(cs) > 0 {
// TODO: check in outer query, like select * from t where t.c = (select c limit 1);
return nil, errors.Errorf("Unknown column '%s' in 'field list'", cs[0])
}
}

View File

@ -69,6 +69,7 @@ func (s *testSelectFieldsPlannerSuite) TestDistinctPlanner(c *C) {
oldFld := s.sr.SelectList.Fields[1]
s.sr.SelectList.Fields = s.sr.SelectList.Fields[:1]
s.sr.SelectList.HiddenFieldOffset = len(s.sr.SelectList.Fields)
p, err = s.sr.Plan(nil)
c.Assert(err, IsNil)
@ -81,6 +82,7 @@ func (s *testSelectFieldsPlannerSuite) TestDistinctPlanner(c *C) {
s.sr.Src = tdp
s.sr.SelectList.Fields = []*field.Field{fld}
s.sr.SelectList.HiddenFieldOffset = len(s.sr.SelectList.Fields)
p, err = s.sr.Plan(nil)
c.Assert(err, IsNil)
@ -94,6 +96,7 @@ func (s *testSelectFieldsPlannerSuite) TestDistinctPlanner(c *C) {
// cover isConst check, like `select 1, c1 from t`
s.sr.SelectList.Fields = []*field.Field{fld, oldFld}
s.sr.SelectList.HiddenFieldOffset = len(s.sr.SelectList.Fields)
p, err = s.sr.Plan(nil)
c.Assert(err, IsNil)

View File

@ -39,70 +39,110 @@ type GroupByRset struct {
SelectList *plans.SelectList
}
type groupByVisitor struct {
expression.BaseVisitor
selectList *plans.SelectList
rootIdent *expression.Ident
}
func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) {
// Group by ambiguous rule:
// select c1 as a, c2 as a from t group by a is ambiguous
// select c1 as a, c2 as a from t group by a + 1 is ambiguous
// select c1 as c2, c2 from t group by c2 is ambiguous
// select c1 as c2, c2 from t group by c2 + 1 is not ambiguous
var (
index int
err error
)
if v.rootIdent == i {
// The group by is an identifier, we must check it first.
index, err = checkIdentAmbiguous(i, v.selectList, groupByClause)
if err != nil {
return nil, errors.Trace(err)
}
}
// first find this identifier in FROM.
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag)
if len(idx) > 0 {
i.ReferScope = expression.IdentReferFromTable
i.ReferIndex = idx[0]
return i, nil
}
if v.rootIdent != i {
// This identifier is the part of the group by, check ambiguous here.
index, err = checkIdentAmbiguous(i, v.selectList, groupByClause)
if err != nil {
return nil, errors.Trace(err)
}
}
// try to find in select list, we have got index using checkIdent before.
if index >= 0 {
// group by can not reference aggregate fields
if _, ok := v.selectList.AggFields[index]; ok {
return nil, errors.Errorf("Reference '%s' not supported (reference to group function)", i)
}
// find in select list
i.ReferScope = expression.IdentReferSelectList
i.ReferIndex = index
return i, nil
}
// TODO: check in out query
return i, errors.Errorf("Unknown column '%s' in 'group statement'", i)
}
func (v *groupByVisitor) VisitCall(c *expression.Call) (expression.Expression, error) {
ok := expression.IsAggregateFunc(c.F)
if ok {
return nil, errors.Errorf("group by cannot contain aggregate function %s", c)
}
var err error
for i, e := range c.Args {
c.Args[i], err = e.Accept(v)
if err != nil {
return nil, errors.Trace(err)
}
}
return c, nil
}
// Plan gets GroupByDefaultPlan.
func (r *GroupByRset) Plan(ctx context.Context) (plan.Plan, error) {
fields := r.SelectList.Fields
if err := updateSelectFieldsRefer(r.SelectList); err != nil {
return nil, errors.Trace(err)
}
r.SelectList.AggFields = GetAggFields(fields)
aggFields := r.SelectList.AggFields
visitor := &groupByVisitor{}
visitor.BaseVisitor.V = visitor
visitor.selectList = r.SelectList
for i, e := range r.By {
if v, ok := e.(expression.Value); ok {
var position int
switch u := v.Val.(type) {
case int64:
position = int(u)
case uint64:
position = int(u)
default:
continue
}
if position < 1 || position > len(fields) {
return nil, errors.Errorf("Unknown column '%d' in 'group statement'", position)
}
index := position - 1
if _, ok := aggFields[index]; ok {
return nil, errors.Errorf("Can't group on '%s'", fields[index])
}
// use Position expression for the associated field.
r.By[i] = &expression.Position{N: position}
} else {
index, err := r.SelectList.CheckReferAmbiguous(e)
if err != nil {
return nil, errors.Errorf("Column '%s' in group statement is ambiguous", e)
} else if _, ok := aggFields[index]; ok {
return nil, errors.Errorf("Can't group on '%s'", e)
}
// TODO: check more ambiguous case
// Group by ambiguous rule:
// select c1 as a, c2 as a from t group by a is ambiguous
// select c1 as a, c2 as a from t group by a + 1 is ambiguous
// select c1 as c2, c2 from t group by c2 is ambiguous
// select c1 as c2, c2 from t group by c2 + 1 is ambiguous
// TODO: use visitor to check aggregate function
names := expression.MentionedColumns(e)
for _, name := range names {
indices := field.GetFieldIndex(name, fields[0:r.SelectList.HiddenFieldOffset], field.DefaultFieldFlag)
if len(indices) == 1 {
// check reference to aggregate function, like `select c1, count(c1) as b from t group by b + 1`.
index := indices[0]
if _, ok := aggFields[index]; ok {
return nil, errors.Errorf("Reference '%s' not supported (reference to group function)", name)
}
}
}
// group by should be an expression, a qualified field name or a select field position,
// but can not contain any aggregate function.
if e := r.By[i]; expression.ContainAggregateFunc(e) {
return nil, errors.Errorf("group by cannot contain aggregate function %s", e.String())
}
pos, err := castPosition(e, r.SelectList, groupByClause)
if err != nil {
return nil, errors.Trace(err)
}
if pos != nil {
// use Position expression for the associated field.
r.By[i] = pos
continue
}
visitor.rootIdent = castIdent(e)
by, err := e.Accept(visitor)
if err != nil {
return nil, errors.Trace(err)
}
r.By[i] = by
}
return &plans.GroupByDefaultPlan{By: r.By, Src: r.Src,

View File

@ -50,6 +50,11 @@ func (s *testGroupByRsetSuite) SetUpSuite(c *C) {
s.r = &GroupByRset{Src: tblPlan, SelectList: selectList, By: by}
}
func resetAggFields(selectList *plans.SelectList) {
fields := selectList.Fields
selectList.AggFields = GetAggFields(fields)
}
func (s *testGroupByRsetSuite) TestGroupByRsetPlan(c *C) {
// `select id, name from t group by name`
p, err := s.r.Plan(nil)
@ -88,6 +93,8 @@ func (s *testGroupByRsetSuite) TestGroupByRsetPlan(c *C) {
fld := &field.Field{Expr: fldExpr, AsName: "a"}
s.r.SelectList.Fields[0] = fld
resetAggFields(s.r.SelectList)
s.r.By[0] = expression.Value{Val: int64(1)}
_, err = s.r.Plan(nil)
@ -115,6 +122,8 @@ func (s *testGroupByRsetSuite) TestGroupByRsetPlan(c *C) {
s.r.SelectList.ResultFields[0].Col.Name = model.NewCIStr("count(id)")
s.r.SelectList.ResultFields[0].Name = "a"
resetAggFields(s.r.SelectList)
_, err = s.r.Plan(nil)
c.Assert(err, NotNil)

View File

@ -31,66 +31,274 @@ type HavingRset struct {
Src plan.Plan
Expr expression.Expression
SelectList *plans.SelectList
// Group by clause
GroupBy []expression.Expression
}
// CheckAndUpdateSelectList checks having fields validity and set hidden fields to selectList.
func (r *HavingRset) CheckAndUpdateSelectList(selectList *plans.SelectList, groupBy []expression.Expression, tableFields []*field.ResultField) error {
// CheckAggregate will check whether order by has aggregate function or not,
// if has, we will add it to select list hidden field.
func (r *HavingRset) CheckAggregate(selectList *plans.SelectList) error {
if expression.ContainAggregateFunc(r.Expr) {
expr, err := selectList.UpdateAggFields(r.Expr, tableFields)
expr, err := selectList.UpdateAggFields(r.Expr)
if err != nil {
return errors.Errorf("%s in 'having clause'", err.Error())
}
r.Expr = expr
} else {
// having can only contain group by column and select list, e.g,
// `select c1 from t group by c2 having c3 > 0` is invalid,
// because c3 is not in group by and select list.
names := expression.MentionedColumns(r.Expr)
for _, name := range names {
found := false
// check name whether in select list.
// notice that `select c1 as c2 from t group by c1, c2, c3 having c2 > c3`,
// will use t.c2 not t.c1 here.
if field.ContainFieldName(name, selectList.ResultFields, field.OrgFieldNameFlag) {
continue
}
if field.ContainFieldName(name, selectList.ResultFields, field.FieldNameFlag) {
if field.ContainFieldName(name, tableFields, field.OrgFieldNameFlag) {
selectList.CloneHiddenField(name, tableFields)
}
continue
}
// check name whether in group by.
// group by must only have column name, e.g,
// `select c1 from t group by c2 having c2 > 0` is valid,
// but `select c1 from t group by c2 + 1 having c2 > 0` is invalid.
for _, by := range groupBy {
if !field.CheckFieldsEqual(name, by.String()) {
continue
}
// if name is not in table fields, it will get an unknown field error in GroupByRset,
// so no need to check return value.
selectList.CloneHiddenField(name, tableFields)
found = true
break
}
if !found {
return errors.Errorf("Unknown column '%s' in 'having clause'", name)
}
}
}
return nil
}
type havingVisitor struct {
expression.BaseVisitor
selectList *plans.SelectList
// for group by
groupBy []expression.Expression
// true means we are visiting aggregate function arguments now.
inAggregate bool
}
func (v *havingVisitor) visitIdentInAggregate(i *expression.Ident) (expression.Expression, error) {
// if we are visiting aggregate function arguments, the identifier first checks in from table,
// then in select list, and outer query finally.
// find this identifier in FROM.
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag)
if len(idx) > 0 {
i.ReferScope = expression.IdentReferFromTable
i.ReferIndex = idx[0]
return i, nil
}
// check in select list.
index, err := checkIdentAmbiguous(i, v.selectList, havingClause)
if err != nil {
return i, errors.Trace(err)
}
if index >= 0 {
// find in select list
i.ReferScope = expression.IdentReferSelectList
i.ReferIndex = index
return i, nil
}
// TODO: check in out query
// TODO: return unknown field error, but now just return directly.
// Because this may reference outer query.
return i, nil
}
func (v *havingVisitor) checkIdentInGroupBy(i *expression.Ident) (*expression.Ident, bool, error) {
for _, by := range v.groupBy {
e := castIdent(by)
if e == nil {
// group by must be a identifier too
continue
}
if !field.CheckFieldsEqual(e.L, i.L) {
// not same, continue
continue
}
// if group by references select list or having is qualified identifier,
// no other check.
if e.ReferScope == expression.IdentReferSelectList || field.IsQualifiedName(i.L) {
i.ReferScope = e.ReferScope
i.ReferIndex = e.ReferIndex
return i, true, nil
}
// having is unqualified name, e.g, select * from t1, t2 group by t1.c having c.
// both t1 and t2 have column c, we must check ambiguous here.
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag)
if len(idx) > 1 {
return i, false, errors.Errorf("Column '%s' in having clause is ambiguous", i)
}
i.ReferScope = e.ReferScope
i.ReferIndex = e.ReferIndex
return i, true, nil
}
return i, false, nil
}
func (v *havingVisitor) checkIdentInSelectList(i *expression.Ident) (*expression.Ident, bool, error) {
index, err := checkIdentAmbiguous(i, v.selectList, havingClause)
if err != nil {
return i, false, errors.Trace(err)
}
if index >= 0 {
// identifier references a select field. use it directly.
// e,g. select c1 as c2 from t having c2, here c2 references c1.
i.ReferScope = expression.IdentReferSelectList
i.ReferIndex = index
return i, true, nil
}
lastIndex := -1
var lastFieldName string
// we may meet this select c1 as c2 from t having c1, so we must check origin field name too.
for index := 0; index < v.selectList.HiddenFieldOffset; index++ {
e := castIdent(v.selectList.Fields[index].Expr)
if e == nil {
// not identifier
continue
}
if !field.CheckFieldsEqual(e.L, i.L) {
// not same, continue
continue
}
if field.IsQualifiedName(i.L) {
// qualified name, no need check
i.ReferScope = expression.IdentReferSelectList
i.ReferIndex = index
return i, true, nil
}
if lastIndex == -1 {
lastIndex = index
lastFieldName = e.L
continue
}
// we may meet select t1.c as a, t2.c as b from t1, t2 having c, must check ambiguous here
if !field.CheckFieldsEqual(lastFieldName, e.L) {
return i, false, errors.Errorf("Column '%s' in having clause is ambiguous", i)
}
i.ReferScope = expression.IdentReferSelectList
i.ReferIndex = index
return i, true, nil
}
if lastIndex != -1 {
i.ReferScope = expression.IdentReferSelectList
i.ReferIndex = lastIndex
return i, true, nil
}
return i, false, nil
}
func (v *havingVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) {
// Having has the most complex rule for ambiguous and identifer reference check.
// Having ambiguous rule:
// select c1 as a, c2 as a from t having a is ambiguous
// select c1 as a, c2 as a from t having a + 1 is ambiguous
// select c1 as c2, c2 from t having c2 is ambiguous
// select c1 as c2, c2 from t having c2 + 1 is ambiguous
// The identifier in having must exist in group by, select list or outer query, so select c1 from t having c2 is wrong,
// the identifier will first try to find the reference in group by, then in select list and finally in outer query.
// Having identifier reference check
// select c1 from t group by c2 having c2, c2 in having references group by c2.
// select c1 as c2 from t having c2, c2 in having references c1 in select field.
// select c1 as c2 from t group by c2 having c2, c2 in having references group by c2.
// select c1 as c2 from t having sum(c2), c2 in having sum(c2) references t.c2 in from table.
// select c1 as c2 from t having sum(c2) + c2, c2 in having sum(c2) references t.c2 in from table, but another c2 references c1 in select list.
// select c1 as c2 from t group by c2 having sum(c2) + c2, c2 in having sum(c2) references t.c2 in from table, another c2 references c2 in group by.
// select c1 as a from t having a, a references c1 in select field.
// select c1 as a from t having sum(a), a in order by sum(a) references c1 in select field.
// select c1 as c2 from t having c1, c1 in having references c1 in select list.
// TODO: unify VisitIdent for group by and order by visitor, they are little different.
if v.inAggregate {
return v.visitIdentInAggregate(i)
}
var (
err error
ok bool
)
// first, check in group by
i, ok, err = v.checkIdentInGroupBy(i)
if err != nil || ok {
return i, errors.Trace(err)
}
// then in select list
i, ok, err = v.checkIdentInSelectList(i)
if err != nil || ok {
return i, errors.Trace(err)
}
// TODO: check in out query
return i, errors.Errorf("Unknown column '%s' in 'having clause'", i)
}
func (v *havingVisitor) VisitPosition(p *expression.Position) (expression.Expression, error) {
n := p.N
// we have added order by to hidden field before, now visit it here.
expr := v.selectList.Fields[n-1].Expr
e, err := expr.Accept(v)
if err != nil {
return nil, errors.Trace(err)
}
v.selectList.Fields[n-1].Expr = e
return p, nil
}
func (v *havingVisitor) VisitCall(c *expression.Call) (expression.Expression, error) {
isAggregate := expression.IsAggregateFunc(c.F)
if v.inAggregate && isAggregate {
// aggregate function can't contain aggregate function
return nil, errors.Errorf("Invalid use of group function")
}
if isAggregate {
// set true to let outer know we are in aggregate function.
v.inAggregate = true
}
var err error
for i, e := range c.Args {
c.Args[i], err = e.Accept(v)
if err != nil {
return nil, errors.Trace(err)
}
}
if isAggregate {
v.inAggregate = false
}
return c, nil
}
// Plan gets HavingPlan.
func (r *HavingRset) Plan(ctx context.Context) (plan.Plan, error) {
visitor := &havingVisitor{
selectList: r.SelectList,
groupBy: r.GroupBy,
inAggregate: false,
}
visitor.BaseVisitor.V = visitor
e, err := r.Expr.Accept(visitor)
if err != nil {
return nil, errors.Trace(err)
}
r.Expr = e
return &plans.HavingPlan{Src: r.Src, Expr: r.Expr, SelectList: r.SelectList}, nil
}

View File

@ -50,59 +50,6 @@ func (s *testHavingRsetSuite) SetUpSuite(c *C) {
s.r = &HavingRset{Src: tblPlan, Expr: expr, SelectList: selectList}
}
func (s *testHavingRsetSuite) TestHavingRsetCheckAndUpdateSelectList(c *C) {
resultFields := s.r.Src.GetFields()
selectList := s.r.SelectList
groupBy := []expression.Expression{}
// `select id, name from t having id > 1`
err := s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields)
c.Assert(err, IsNil)
// `select name from t group by id having id > 1`
selectList.ResultFields = selectList.ResultFields[1:]
selectList.Fields = selectList.Fields[1:]
groupBy = []expression.Expression{&expression.Ident{CIStr: model.NewCIStr("id")}}
err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields)
c.Assert(err, IsNil)
// `select name from t group by id + 1 having id > 1`
expr := expression.NewBinaryOperation(opcode.Plus, &expression.Ident{CIStr: model.NewCIStr("id")}, expression.Value{Val: 1})
groupBy = []expression.Expression{expr}
err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields)
c.Assert(err, IsNil)
// `select name from t group by id + 1 having count(1) > 1`
aggExpr, err := expression.NewCall("count", []expression.Expression{expression.Value{Val: 1}}, false)
c.Assert(err, IsNil)
s.r.Expr = aggExpr
err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields)
c.Assert(err, IsNil)
// `select name from t group by id + 1 having count(xxx) > 1`
aggExpr, err = expression.NewCall("count", []expression.Expression{&expression.Ident{CIStr: model.NewCIStr("xxx")}}, false)
c.Assert(err, IsNil)
s.r.Expr = aggExpr
err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields)
c.Assert(err, NotNil)
// `select name from t group by id having xxx > 1`
expr = expression.NewBinaryOperation(opcode.GT, &expression.Ident{CIStr: model.NewCIStr("xxx")}, expression.Value{Val: 1})
s.r.Expr = expr
err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields)
c.Assert(err, NotNil)
}
func (s *testHavingRsetSuite) TestHavingRsetPlan(c *C) {
p, err := s.r.Plan(nil)
c.Assert(err, IsNil)

View File

@ -14,13 +14,15 @@
package rsets
import (
"github.com/juju/errors"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/field"
"github.com/pingcap/tidb/plan/plans"
)
// GetAggFields gets aggregate fields position map.
func GetAggFields(fields []*field.Field) map[int]struct{} {
aggFields := map[int]struct{}{}
aggFields := make(map[int]struct{}, len(fields))
for i, v := range fields {
if expression.ContainAggregateFunc(v.Expr) {
aggFields[i] = struct{}{}
@ -34,3 +36,122 @@ func HasAggFields(fields []*field.Field) bool {
aggFields := GetAggFields(fields)
return len(aggFields) > 0
}
// castIdent returns an Ident expression if e is or nil.
func castIdent(e expression.Expression) *expression.Ident {
i, ok := e.(*expression.Ident)
if !ok {
return nil
}
return i
}
// TODO: export clause type and move to plan?
type clauseType int
const (
noneClause clauseType = iota
onClause
whereClause
groupByClause
fieldListClause
havingClause
orderByClause
)
func (clause clauseType) String() string {
switch clause {
case onClause:
return "on clause"
case fieldListClause:
return "field list"
case whereClause:
return "where clause"
case groupByClause:
return "group statement"
case orderByClause:
return "order clause"
case havingClause:
return "having clause"
}
return "none"
}
// castPosition returns an group/order by Position expression if e is a number.
func castPosition(e expression.Expression, selectList *plans.SelectList, clause clauseType) (*expression.Position, error) {
v, ok := e.(expression.Value)
if !ok {
return nil, nil
}
var position int
switch u := v.Val.(type) {
case int64:
position = int(u)
case uint64:
position = int(u)
default:
return nil, nil
}
if position < 1 || position > selectList.HiddenFieldOffset {
return nil, errors.Errorf("Unknown column '%d' in '%s'", position, clause)
}
if clause == groupByClause {
index := position - 1
if _, ok := selectList.AggFields[index]; ok {
return nil, errors.Errorf("Can't group on '%s'", selectList.Fields[index])
}
}
// use Position expression for the associated field.
return &expression.Position{N: position}, nil
}
func checkIdentAmbiguous(i *expression.Ident, selectList *plans.SelectList, clause clauseType) (int, error) {
index, err := selectList.CheckAmbiguous(i)
if err != nil {
return -1, errors.Errorf("Column '%s' in %s is ambiguous", i, clause)
} else if index == -1 {
return -1, nil
}
return index, nil
}
// fromIdentVisitor can only handle identifier which reference FROM table or outer query.
// like in common select list, where or join on condition.
type fromIdentVisitor struct {
expression.BaseVisitor
fromFields []*field.ResultField
clause clauseType
}
func (v *fromIdentVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) {
idx := field.GetResultFieldIndex(i.L, v.fromFields, field.DefaultFieldFlag)
if len(idx) == 1 {
i.ReferScope = expression.IdentReferFromTable
i.ReferIndex = idx[0]
return i, nil
} else if len(idx) > 1 {
return nil, errors.Errorf("Column '%s' in %s is ambiguous", i, v.clause)
}
if v.clause == onClause {
// on clause can't check outer query.
return nil, errors.Errorf("Unknown column '%s' in '%s'", i, v.clause)
}
// TODO: check in outer query
return i, nil
}
func newFromIdentVisitor(fromFields []*field.ResultField, clause clauseType) *fromIdentVisitor {
visitor := &fromIdentVisitor{}
visitor.BaseVisitor.V = visitor
visitor.fromFields = fromFields
visitor.clause = clause
return visitor
}

View File

@ -141,6 +141,17 @@ func (r *JoinRset) buildJoinPlan(ctx context.Context, p *plans.JoinPlan, s *Join
p.Fields = append(p.Fields, rightFields...)
}
if p.On != nil {
visitor := newFromIdentVisitor(p.Fields, onClause)
e, err := p.On.Accept(visitor)
if err != nil {
return errors.Trace(err)
}
p.On = e
}
return nil
}

View File

@ -65,46 +65,110 @@ func (r *OrderByRset) String() string {
return strings.Join(a, ", ")
}
// CheckAndUpdateSelectList checks order by fields validity and set hidden fields to selectList.
func (r *OrderByRset) CheckAndUpdateSelectList(selectList *plans.SelectList, tableFields []*field.ResultField) error {
// CheckAggregate will check whether order by has aggregate function or not,
// if has, we will add it to select list hidden field.
func (r *OrderByRset) CheckAggregate(selectList *plans.SelectList) error {
for i, v := range r.By {
if expression.ContainAggregateFunc(v.Expr) {
expr, err := selectList.UpdateAggFields(v.Expr, tableFields)
expr, err := selectList.UpdateAggFields(v.Expr)
if err != nil {
return errors.Errorf("%s in 'order clause'", err.Error())
}
r.By[i].Expr = expr
} else {
if _, err := selectList.CheckReferAmbiguous(v.Expr); err != nil {
return errors.Errorf("Column '%s' in order statement is ambiguous", v.Expr)
}
}
}
return nil
}
// TODO: check more ambiguous case
// Order by ambiguous rule:
// select c1 as a, c2 as a from t order by a is ambiguous
// select c1 as a, c2 as a from t order by a + 1 is ambiguous
// select c1 as c2, c2 from t order by c2 is ambiguous
// select c1 as c2, c2 from t order by c2 + 1 is ambiguous
type orderByVisitor struct {
expression.BaseVisitor
selectList *plans.SelectList
rootIdent *expression.Ident
}
// TODO: use vistor to refactor all and combine following plan check.
names := expression.MentionedColumns(v.Expr)
for _, name := range names {
// try to find in select list
// TODO: mysql has confused result for this, see #555.
// now we use select list then order by, later we should make it easier.
if field.ContainFieldName(name, selectList.ResultFields, field.CheckFieldFlag) {
continue
}
func (v *orderByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) {
// Order by ambiguous rule:
// select c1 as a, c2 as a from t order by a is ambiguous
// select c1 as a, c2 as a from t order by a + 1 is ambiguous
// select c1 as c2, c2 from t order by c2 is ambiguous
// select c1 as c2, c2 from t order by c2 + 1 is not ambiguous
if !selectList.CloneHiddenField(name, tableFields) {
return errors.Errorf("Unknown column '%s' in 'order clause'", name)
}
}
// Order by identifier reference check
// select c1 as c2 from t order by c2, c2 in order by references c1 in select field.
// select c1 as c2 from t order by c2 + 1, c2 in order by c2 + 1 references t.c2 in from table.
// select c1 as c2 from t order by sum(c2), c2 in order by sum(c2) references t.c2 in from table.
// select c1 as c2 from t order by sum(1) + c2, c2 in order by sum(1) + c2 references t.c2 in from table.
// select c1 as a from t order by a, a references c1 in select field.
// select c1 as a from t order by sum(a), a in order by sum(a) references c1 in select field.
// TODO: unify VisitIdent for group by and order by visitor, they are little different.
var (
index int
err error
)
if v.rootIdent == i {
// The order by is an identifier, we must check it first.
index, err = checkIdentAmbiguous(i, v.selectList, orderByClause)
if err != nil {
return nil, errors.Trace(err)
}
if index >= 0 {
// identifier references a select field. use it directly.
// e,g. select c1 as c2 from t order by c2, here c2 references c1.
i.ReferScope = expression.IdentReferSelectList
i.ReferIndex = index
return i, nil
}
}
return nil
// find this identifier in FROM.
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag)
if len(idx) > 0 {
i.ReferScope = expression.IdentReferFromTable
i.ReferIndex = idx[0]
return i, nil
}
if v.rootIdent != i {
// This identifier is the part of the order by, check ambiguous here.
index, err = checkIdentAmbiguous(i, v.selectList, orderByClause)
if err != nil {
return nil, errors.Trace(err)
}
}
// try to find in select list, we have got index using checkIdentAmbiguous before.
if index >= 0 {
// find in select list
i.ReferScope = expression.IdentReferSelectList
i.ReferIndex = index
return i, nil
}
// TODO: check in out query
return i, errors.Errorf("Unknown column '%s' in 'order clause'", i)
}
func (v *orderByVisitor) VisitPosition(p *expression.Position) (expression.Expression, error) {
n := p.N
if p.N <= v.selectList.HiddenFieldOffset {
// this position expression is not in hidden field, no need to check
return p, nil
}
// we have added order by to hidden field before, now visit it here.
expr := v.selectList.Fields[n-1].Expr
e, err := expr.Accept(v)
if err != nil {
return nil, errors.Trace(err)
}
v.selectList.Fields[n-1].Expr = e
return p, nil
}
// Plan get SrcPlan/OrderByDefaultPlan.
@ -120,45 +184,28 @@ func (r *OrderByRset) Plan(ctx context.Context) (plan.Plan, error) {
ascs []bool
)
fields := r.Src.GetFields()
visitor := &orderByVisitor{}
visitor.BaseVisitor.V = visitor
visitor.selectList = r.SelectList
for i := range r.By {
e := r.By[i].Expr
if v, ok := e.(expression.Value); ok {
var (
position int
isPosition = true
)
switch u := v.Val.(type) {
case int64:
position = int(u)
case uint64:
position = int(u)
default:
isPosition = false
// only const value
}
pos, err := castPosition(e, r.SelectList, orderByClause)
if err != nil {
return nil, errors.Trace(err)
}
if isPosition {
if position < 1 || position > len(fields) {
return nil, errors.Errorf("Unknown column '%d' in 'order clause'", position)
}
// use Position expression for the associated field.
r.By[i].Expr = &expression.Position{N: position}
}
if pos != nil {
// use Position expression for the associated field.
r.By[i].Expr = pos
} else {
// Don't check ambiguous here, only check field exists or not.
// TODO: use visitor to refactor.
colNames := expression.MentionedColumns(e)
for _, name := range colNames {
if idx := field.GetResultFieldIndex(name, r.SelectList.ResultFields, field.DefaultFieldFlag); len(idx) == 0 {
// find in from
if idx = field.GetResultFieldIndex(name, r.SelectList.FromFields, field.DefaultFieldFlag); len(idx) == 0 {
return nil, errors.Errorf("unknown field %s", name)
}
}
visitor.rootIdent = castIdent(e)
e, err = e.Accept(visitor)
if err != nil {
return nil, errors.Trace(err)
}
r.By[i].Expr = e
}
by = append(by, r.By[i].Expr)

View File

@ -47,64 +47,6 @@ func (s *testOrderByRsetSuite) SetUpSuite(c *C) {
s.r = &OrderByRset{Src: tblPlan, SelectList: selectList}
}
func (s *testOrderByRsetSuite) TestOrderByRsetCheckAndUpdateSelectList(c *C) {
resultFields := s.r.Src.GetFields()
fields := make([]*field.Field, len(resultFields))
for i, resultField := range resultFields {
name := resultField.Name
fields[i] = &field.Field{Expr: &expression.Ident{CIStr: model.NewCIStr(name)}}
}
selectList := &plans.SelectList{
HiddenFieldOffset: len(resultFields),
ResultFields: resultFields,
Fields: fields,
}
expr := &expression.Ident{CIStr: model.NewCIStr("id")}
orderByItem := OrderByItem{Expr: expr, Asc: true}
by := []OrderByItem{orderByItem}
r := &OrderByRset{By: by, SelectList: selectList}
// `select id, name from t order by id`
err := r.CheckAndUpdateSelectList(selectList, resultFields)
c.Assert(err, IsNil)
// `select id, name as id from t order by id`
selectList.Fields[1].AsName = "id"
selectList.ResultFields[1].Name = "id"
err = r.CheckAndUpdateSelectList(selectList, resultFields)
c.Assert(err, NotNil)
// `select id, name from t order by count(1) > 1`
aggExpr, err := expression.NewCall("count", []expression.Expression{expression.Value{Val: 1}}, false)
c.Assert(err, IsNil)
r.By[0].Expr = aggExpr
err = r.CheckAndUpdateSelectList(selectList, resultFields)
c.Assert(err, IsNil)
// `select id, name from t order by count(xxx) > 1`
aggExpr, err = expression.NewCall("count", []expression.Expression{&expression.Ident{CIStr: model.NewCIStr("xxx")}}, false)
c.Assert(err, IsNil)
r.By[0].Expr = aggExpr
err = r.CheckAndUpdateSelectList(selectList, resultFields)
c.Assert(err, NotNil)
// `select id, name from t order by xxx`
r.By[0].Expr = &expression.Ident{CIStr: model.NewCIStr("xxx")}
selectList.Fields[1].AsName = "name"
selectList.ResultFields[1].Name = "name"
err = r.CheckAndUpdateSelectList(selectList, resultFields)
c.Assert(err, NotNil)
}
func (s *testOrderByRsetSuite) TestOrderByRsetPlan(c *C) {
// `select id, name from t`
_, err := s.r.Plan(nil)
@ -145,12 +87,6 @@ func (s *testOrderByRsetSuite) TestOrderByRsetPlan(c *C) {
_, err = s.r.Plan(nil)
c.Assert(err, IsNil)
// `select id, name from t order by xxx`
s.r.By[0].Expr = &expression.Ident{CIStr: model.NewCIStr("xxx")}
_, err = s.r.Plan(nil)
c.Assert(err, NotNil)
// check src plan is NullPlan
s.r.Src = &plans.NullPlan{Fields: s.r.SelectList.ResultFields}

View File

@ -162,7 +162,6 @@ func (r *WhereRset) planStatic(ctx context.Context, e expression.Expression) (pl
if err != nil {
return nil, err
}
if val == nil {
// like `select * from t where null`.
return &plans.NullPlan{Fields: r.Src.GetFields()}, nil
@ -172,7 +171,6 @@ func (r *WhereRset) planStatic(ctx context.Context, e expression.Expression) (pl
if err != nil {
return nil, err
}
if n == 0 {
// like `select * from t where 0`.
return &plans.NullPlan{Fields: r.Src.GetFields()}, nil
@ -181,8 +179,25 @@ func (r *WhereRset) planStatic(ctx context.Context, e expression.Expression) (pl
return &plans.FilterDefaultPlan{Plan: r.Src, Expr: e}, nil
}
func (r *WhereRset) updateWhereFieldsRefer() error {
visitor := newFromIdentVisitor(r.Src.GetFields(), whereClause)
e, err := r.Expr.Accept(visitor)
if err != nil {
return errors.Trace(err)
}
r.Expr = e
return nil
}
// Plan gets NullPlan/FilterDefaultPlan.
func (r *WhereRset) Plan(ctx context.Context) (plan.Plan, error) {
// Update where fields refer expr by using fromIdentVisitor.
if err := r.updateWhereFieldsRefer(); err != nil {
return nil, errors.Trace(err)
}
expr := r.Expr.Clone()
if expr.IsStatic() {
// IsStaic means we have a const value for where condition, and we don't need any index.

View File

@ -31,7 +31,7 @@ import (
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/field"
"github.com/pingcap/tidb/kv"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/rset"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/db"

View File

@ -15,7 +15,7 @@ package variable
import (
"github.com/pingcap/tidb/context"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/mysql"
)
// SessionVars is to handle user-defined or global variables in current session.

Some files were not shown because too many files have changed in this diff Show More