*: merge master
This commit is contained in:
@ -22,7 +22,7 @@ import (
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/ngaut/log"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/util/errors"
|
||||
"github.com/pingcap/tidb/util/errors2"
|
||||
)
|
||||
|
||||
@ -24,7 +24,7 @@ import (
|
||||
"github.com/pingcap/tidb/context"
|
||||
"github.com/pingcap/tidb/kv"
|
||||
"github.com/pingcap/tidb/model"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ import (
|
||||
|
||||
. "github.com/pingcap/check"
|
||||
"github.com/pingcap/tidb/model"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ import (
|
||||
"github.com/pingcap/tidb/kv"
|
||||
"github.com/pingcap/tidb/meta"
|
||||
"github.com/pingcap/tidb/model"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/parser/coldef"
|
||||
"github.com/pingcap/tidb/table"
|
||||
"github.com/pingcap/tidb/table/tables"
|
||||
@ -512,7 +512,7 @@ func updateOldRows(ctx context.Context, t *tables.Table, col *column.Col) error
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
it, err := txn.Seek([]byte(t.FirstKey()), nil)
|
||||
it, err := txn.Seek([]byte(t.FirstKey()))
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
@ -685,7 +685,7 @@ func (d *ddl) buildIndex(ctx context.Context, t table.Table, idxInfo *model.Inde
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
it, err := txn.Seek([]byte(firstKey), nil)
|
||||
it, err := txn.Seek([]byte(firstKey))
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
@ -23,7 +23,7 @@ import (
|
||||
"github.com/pingcap/tidb/ddl"
|
||||
"github.com/pingcap/tidb/kv"
|
||||
"github.com/pingcap/tidb/model"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/parser"
|
||||
"github.com/pingcap/tidb/sessionctx"
|
||||
"github.com/pingcap/tidb/stmt"
|
||||
|
||||
@ -31,7 +31,7 @@ import (
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/kv"
|
||||
"github.com/pingcap/tidb/model"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/rset"
|
||||
"github.com/pingcap/tidb/sessionctx"
|
||||
qerror "github.com/pingcap/tidb/util/errors"
|
||||
|
||||
@ -23,7 +23,7 @@ import (
|
||||
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/context"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/parser/opcode"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
@ -20,7 +20,7 @@ import (
|
||||
. "github.com/pingcap/check"
|
||||
"github.com/pingcap/tidb/expression/builtin"
|
||||
"github.com/pingcap/tidb/model"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/parser/opcode"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
@ -186,7 +186,7 @@ func (s *testBinOpSuite) TestIdentRelOp(c *C) {
|
||||
|
||||
f := func(name string) *Ident {
|
||||
return &Ident{
|
||||
model.NewCIStr(name),
|
||||
CIStr: model.NewCIStr(name),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -24,7 +24,7 @@ import (
|
||||
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/kv/memkv"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@ package builtin
|
||||
|
||||
import (
|
||||
. "github.com/pingcap/check"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ import (
|
||||
"time"
|
||||
|
||||
. "github.com/pingcap/check"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
)
|
||||
|
||||
func (s *testBuiltinSuite) TestLength(c *C) {
|
||||
|
||||
@ -21,7 +21,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/juju/errors"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ import (
|
||||
"time"
|
||||
|
||||
. "github.com/pingcap/check"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
)
|
||||
|
||||
func (s *testBuiltinSuite) TestDate(c *C) {
|
||||
|
||||
@ -18,7 +18,7 @@ import (
|
||||
|
||||
"github.com/pingcap/tidb/context"
|
||||
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ import (
|
||||
"errors"
|
||||
|
||||
. "github.com/pingcap/check"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/util/charset"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
104
expression/date_add.go
Normal file
104
expression/date_add.go
Normal file
@ -0,0 +1,104 @@
|
||||
// Copyright 2015 PingCAP, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package expression
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/context"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
// DateAdd is for time date_add function.
|
||||
// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add
|
||||
type DateAdd struct {
|
||||
Unit string
|
||||
Date Expression
|
||||
Interval Expression
|
||||
}
|
||||
|
||||
// Clone implements the Expression Clone interface.
|
||||
func (da *DateAdd) Clone() Expression {
|
||||
n := *da
|
||||
return &n
|
||||
}
|
||||
|
||||
// Eval implements the Expression Eval interface.
|
||||
func (da *DateAdd) Eval(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) {
|
||||
dv, err := da.Date.Eval(ctx, args)
|
||||
if dv == nil || err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
sv, err := types.ToString(dv)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
f := types.NewFieldType(mysql.TypeDatetime)
|
||||
f.Decimal = mysql.MaxFsp
|
||||
|
||||
dv, err = types.Convert(sv, f)
|
||||
if dv == nil || err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
t, ok := dv.(mysql.Time)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("need time type, but got %T", dv)
|
||||
}
|
||||
|
||||
iv, err := da.Interval.Eval(ctx, args)
|
||||
if iv == nil || err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
format, err := types.ToString(iv)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
years, months, days, durations, err := mysql.ExtractTimeValue(da.Unit, strings.TrimSpace(format))
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
t.Time = t.Time.Add(durations)
|
||||
t.Time = t.Time.AddDate(int(years), int(months), int(days))
|
||||
|
||||
// "2011-11-11 10:10:20.000000" outputs "2011-11-11 10:10:20".
|
||||
if t.Time.Nanosecond() == 0 {
|
||||
t.Fsp = 0
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// IsStatic implements the Expression IsStatic interface.
|
||||
func (da *DateAdd) IsStatic() bool {
|
||||
return da.Date.IsStatic() && da.Interval.IsStatic()
|
||||
}
|
||||
|
||||
// String implements the Expression String interface.
|
||||
func (da *DateAdd) String() string {
|
||||
return fmt.Sprintf("DATE_ADD(%s, INTERVAL %s %s)", da.Date, da.Interval, strings.ToUpper(da.Unit))
|
||||
}
|
||||
|
||||
// Accept implements the Visitor Accept interface.
|
||||
func (da *DateAdd) Accept(v Visitor) (Expression, error) {
|
||||
return v.VisitDateAdd(da)
|
||||
}
|
||||
144
expression/date_add_test.go
Normal file
144
expression/date_add_test.go
Normal file
@ -0,0 +1,144 @@
|
||||
// Copyright 2015 PingCAP, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package expression
|
||||
|
||||
import (
|
||||
. "github.com/pingcap/check"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
)
|
||||
|
||||
var _ = Suite(&testDateAddSuite{})
|
||||
|
||||
type testDateAddSuite struct {
|
||||
}
|
||||
|
||||
func (t *testDateAddSuite) TestDateAdd(c *C) {
|
||||
input := "2011-11-11 10:10:10"
|
||||
e := &DateAdd{
|
||||
Unit: "DAY",
|
||||
Date: Value{Val: input},
|
||||
Interval: Value{Val: "1"},
|
||||
}
|
||||
c.Assert(e.String(), Equals, `DATE_ADD("2011-11-11 10:10:10", INTERVAL "1" DAY)`)
|
||||
c.Assert(e.Clone(), NotNil)
|
||||
c.Assert(e.IsStatic(), IsTrue)
|
||||
|
||||
_, err := e.Eval(nil, nil)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
// Test null.
|
||||
e = &DateAdd{
|
||||
Unit: "DAY",
|
||||
Date: Value{Val: nil},
|
||||
Interval: Value{Val: "1"},
|
||||
}
|
||||
|
||||
v, err := e.Eval(nil, nil)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(v, IsNil)
|
||||
|
||||
e = &DateAdd{
|
||||
Unit: "DAY",
|
||||
Date: Value{Val: input},
|
||||
Interval: Value{Val: nil},
|
||||
}
|
||||
|
||||
v, err = e.Eval(nil, nil)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(v, IsNil)
|
||||
|
||||
// Test eval.
|
||||
tbl := []struct {
|
||||
Unit string
|
||||
Interval interface{}
|
||||
Expect string
|
||||
}{
|
||||
{"MICROSECOND", "1000", "2011-11-11 10:10:10.001000"},
|
||||
{"MICROSECOND", 1000, "2011-11-11 10:10:10.001000"},
|
||||
{"SECOND", "10", "2011-11-11 10:10:20"},
|
||||
{"MINUTE", "10", "2011-11-11 10:20:10"},
|
||||
{"HOUR", "10", "2011-11-11 20:10:10"},
|
||||
{"DAY", "11", "2011-11-22 10:10:10"},
|
||||
{"WEEK", "2", "2011-11-25 10:10:10"},
|
||||
{"MONTH", "2", "2012-01-11 10:10:10"},
|
||||
{"QUARTER", "4", "2012-11-11 10:10:10"},
|
||||
{"YEAR", "2", "2013-11-11 10:10:10"},
|
||||
{"SECOND_MICROSECOND", "10.00100000", "2011-11-11 10:10:20.100000"},
|
||||
{"SECOND_MICROSECOND", "10.0010000000", "2011-11-11 10:10:30"},
|
||||
{"SECOND_MICROSECOND", "10.0010000010", "2011-11-11 10:10:30.000010"},
|
||||
{"MINUTE_MICROSECOND", "10:10.100", "2011-11-11 10:20:20.100000"},
|
||||
{"MINUTE_SECOND", "10:10", "2011-11-11 10:20:20"},
|
||||
{"HOUR_MICROSECOND", "10:10:10.100", "2011-11-11 20:20:20.100000"},
|
||||
{"HOUR_SECOND", "10:10:10", "2011-11-11 20:20:20"},
|
||||
{"HOUR_MINUTE", "10:10", "2011-11-11 20:20:10"},
|
||||
{"DAY_MICROSECOND", "11 10:10:10.100", "2011-11-22 20:20:20.100000"},
|
||||
{"DAY_SECOND", "11 10:10:10", "2011-11-22 20:20:20"},
|
||||
{"DAY_MINUTE", "11 10:10", "2011-11-22 20:20:10"},
|
||||
{"DAY_HOUR", "11 10", "2011-11-22 20:10:10"},
|
||||
{"YEAR_MONTH", "11-1", "2022-12-11 10:10:10"},
|
||||
{"YEAR_MONTH", "11-11", "2023-10-11 10:10:10"},
|
||||
}
|
||||
|
||||
for _, t := range tbl {
|
||||
e := &DateAdd{
|
||||
Unit: t.Unit,
|
||||
Date: Value{Val: input},
|
||||
Interval: Value{Val: t.Interval},
|
||||
}
|
||||
|
||||
v, err := e.Eval(nil, nil)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
value, ok := v.(mysql.Time)
|
||||
c.Assert(ok, IsTrue)
|
||||
c.Assert(value.String(), Equals, t.Expect)
|
||||
}
|
||||
|
||||
// Test error.
|
||||
errInput := "20111111 10:10:10"
|
||||
errTbl := []struct {
|
||||
Unit string
|
||||
Interval interface{}
|
||||
}{
|
||||
{"MICROSECOND", "abc1000"},
|
||||
{"MICROSECOND", ""},
|
||||
{"SECOND_MICROSECOND", "10"},
|
||||
{"MINUTE_MICROSECOND", "10.0000"},
|
||||
{"MINUTE_MICROSECOND", "10:10:10.0000"},
|
||||
|
||||
// MySQL support, but tidb not.
|
||||
{"HOUR_MICROSECOND", "10:10.0000"},
|
||||
{"YEAR_MONTH", "10 1"},
|
||||
}
|
||||
|
||||
for _, t := range errTbl {
|
||||
e := &DateAdd{
|
||||
Unit: t.Unit,
|
||||
Date: Value{Val: input},
|
||||
Interval: Value{Val: t.Interval},
|
||||
}
|
||||
|
||||
_, err := e.Eval(nil, nil)
|
||||
c.Assert(err, NotNil)
|
||||
|
||||
e = &DateAdd{
|
||||
Unit: t.Unit,
|
||||
Date: Value{Val: errInput},
|
||||
Interval: Value{Val: t.Interval},
|
||||
}
|
||||
|
||||
v, err := e.Eval(nil, nil)
|
||||
c.Assert(err, NotNil, Commentf("%s", v))
|
||||
}
|
||||
}
|
||||
@ -19,12 +19,12 @@ import (
|
||||
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/context"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
// Extract is for time extract function.
|
||||
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_extract
|
||||
// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_extract
|
||||
type Extract struct {
|
||||
Unit string
|
||||
Date Expression
|
||||
@ -56,7 +56,7 @@ func (e *Extract) Eval(ctx context.Context, args map[interface{}]interface{}) (i
|
||||
return nil, errors.Errorf("need time type, but got %T", v)
|
||||
}
|
||||
|
||||
n, err1 := extractTime(e.Unit, t)
|
||||
n, err1 := mysql.ExtractTimeNum(e.Unit, t)
|
||||
if err1 != nil {
|
||||
return nil, errors.Trace(err1)
|
||||
}
|
||||
@ -78,70 +78,3 @@ func (e *Extract) String() string {
|
||||
func (e *Extract) Accept(v Visitor) (Expression, error) {
|
||||
return v.VisitExtract(e)
|
||||
}
|
||||
|
||||
func extractTime(unit string, t mysql.Time) (int64, error) {
|
||||
switch strings.ToUpper(unit) {
|
||||
case "MICROSECOND":
|
||||
return int64(t.Nanosecond() / 1000), nil
|
||||
case "SECOND":
|
||||
return int64(t.Second()), nil
|
||||
case "MINUTE":
|
||||
return int64(t.Minute()), nil
|
||||
case "HOUR":
|
||||
return int64(t.Hour()), nil
|
||||
case "DAY":
|
||||
return int64(t.Day()), nil
|
||||
case "WEEK":
|
||||
_, week := t.ISOWeek()
|
||||
return int64(week), nil
|
||||
case "MONTH":
|
||||
return int64(t.Month()), nil
|
||||
case "QUARTER":
|
||||
m := int64(t.Month())
|
||||
// 1 - 3 -> 1
|
||||
// 4 - 6 -> 2
|
||||
// 7 - 9 -> 3
|
||||
// 10 - 12 -> 4
|
||||
return (m + 2) / 3, nil
|
||||
case "YEAR":
|
||||
return int64(t.Year()), nil
|
||||
case "SECOND_MICROSECOND":
|
||||
return int64(t.Second())*1000000 + int64(t.Nanosecond())/1000, nil
|
||||
case "MINUTE_MICROSECOND":
|
||||
_, m, s := t.Clock()
|
||||
return int64(m)*100000000 + int64(s)*1000000 + int64(t.Nanosecond())/1000, nil
|
||||
case "MINUTE_SECOND":
|
||||
_, m, s := t.Clock()
|
||||
return int64(m*100 + s), nil
|
||||
case "HOUR_MICROSECOND":
|
||||
h, m, s := t.Clock()
|
||||
return int64(h)*10000000000 + int64(m)*100000000 + int64(s)*1000000 + int64(t.Nanosecond())/1000, nil
|
||||
case "HOUR_SECOND":
|
||||
h, m, s := t.Clock()
|
||||
return int64(h)*10000 + int64(m)*100 + int64(s), nil
|
||||
case "HOUR_MINUTE":
|
||||
h, m, _ := t.Clock()
|
||||
return int64(h)*100 + int64(m), nil
|
||||
case "DAY_MICROSECOND":
|
||||
h, m, s := t.Clock()
|
||||
d := t.Day()
|
||||
return int64(d*1000000+h*10000+m*100+s)*1000000 + int64(t.Nanosecond())/1000, nil
|
||||
case "DAY_SECOND":
|
||||
h, m, s := t.Clock()
|
||||
d := t.Day()
|
||||
return int64(d)*1000000 + int64(h)*10000 + int64(m)*100 + int64(s), nil
|
||||
case "DAY_MINUTE":
|
||||
h, m, _ := t.Clock()
|
||||
d := t.Day()
|
||||
return int64(d)*10000 + int64(h)*100 + int64(m), nil
|
||||
case "DAY_HOUR":
|
||||
h, _, _ := t.Clock()
|
||||
d := t.Day()
|
||||
return int64(d)*100 + int64(h), nil
|
||||
case "YEAR_MONTH":
|
||||
y, m, _ := t.Date()
|
||||
return int64(y)*100 + int64(m), nil
|
||||
default:
|
||||
return 0, errors.Errorf("invalid unit %s", unit)
|
||||
}
|
||||
}
|
||||
|
||||
@ -30,7 +30,7 @@ import (
|
||||
"github.com/pingcap/tidb/context"
|
||||
"github.com/pingcap/tidb/expression/builtin"
|
||||
"github.com/pingcap/tidb/model"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/parser/opcode"
|
||||
"github.com/pingcap/tidb/sessionctx/variable"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
@ -45,13 +45,15 @@ const (
|
||||
ExprEvalPositionFunc = "$positionFunc"
|
||||
// ExprEvalValuesFunc is the key saving a function to retrieve value for column name.
|
||||
ExprEvalValuesFunc = "$valuesFunc"
|
||||
// ExprEvalIdentReferFunc is the key saving a function to retrieve value with identifier reference index.
|
||||
ExprEvalIdentReferFunc = "$identReferFunc"
|
||||
)
|
||||
|
||||
var (
|
||||
// CurrentTimestamp is the keyword getting default value for datetime and timestamp type.
|
||||
CurrentTimestamp = "CURRENT_TIMESTAMP"
|
||||
// CurrentTimeExpr is the expression retireving default value for datetime and timestamp type.
|
||||
CurrentTimeExpr = &Ident{model.NewCIStr(CurrentTimestamp)}
|
||||
CurrentTimeExpr = &Ident{CIStr: model.NewCIStr(CurrentTimestamp)}
|
||||
// ZeroTimestamp shows the zero datetime and timestamp.
|
||||
ZeroTimestamp = "0000-00-00 00:00:00"
|
||||
)
|
||||
@ -142,22 +144,40 @@ func newMentionedAggregateFuncsVisitor() *mentionedAggregateFuncsVisitor {
|
||||
}
|
||||
|
||||
func (v *mentionedAggregateFuncsVisitor) VisitCall(c *Call) (Expression, error) {
|
||||
isAggregate := IsAggregateFunc(c.F)
|
||||
|
||||
if isAggregate {
|
||||
v.exprs = append(v.exprs, c)
|
||||
}
|
||||
|
||||
n := len(v.exprs)
|
||||
for _, e := range c.Args {
|
||||
_, err := e.Accept(v)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
}
|
||||
f, ok := builtin.Funcs[strings.ToLower(c.F)]
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unknown function %s", c.F)
|
||||
}
|
||||
if f.IsAggregate {
|
||||
v.exprs = append(v.exprs, c)
|
||||
|
||||
if isAggregate && len(v.exprs) != n {
|
||||
// aggregate function can't use aggregate function as the arg.
|
||||
// here means we have aggregate function in arg.
|
||||
return nil, errors.Errorf("Invalid use of group function")
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// IsAggregateFunc checks whether name is an aggregate function or not.
|
||||
func IsAggregateFunc(name string) bool {
|
||||
// TODO: use switch defined aggregate name "sum", "count", etc... directly.
|
||||
// Maybe we can remove builtin IsAggregate field later.
|
||||
f, ok := builtin.Funcs[strings.ToLower(name)]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return f.IsAggregate
|
||||
}
|
||||
|
||||
// MentionedColumns returns a list of names for Ident expression.
|
||||
func MentionedColumns(e Expression) []string {
|
||||
var names []string
|
||||
|
||||
@ -5,9 +5,8 @@ import (
|
||||
"time"
|
||||
|
||||
. "github.com/pingcap/check"
|
||||
|
||||
"github.com/pingcap/tidb/model"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/parser/opcode"
|
||||
"github.com/pingcap/tidb/sessionctx/variable"
|
||||
)
|
||||
@ -47,6 +46,15 @@ func (s *testHelperSuite) TestContainAggFunc(c *C) {
|
||||
b := ContainAggregateFunc(t.Expr)
|
||||
c.Assert(b, Equals, t.Expect)
|
||||
}
|
||||
|
||||
expr := &Call{
|
||||
F: "count",
|
||||
Args: []Expression{
|
||||
&Call{F: "count", Args: []Expression{v}},
|
||||
},
|
||||
}
|
||||
_, err := MentionedAggregateFuncs(expr)
|
||||
c.Assert(err, NotNil)
|
||||
}
|
||||
|
||||
func (s *testHelperSuite) TestMentionedColumns(c *C) {
|
||||
@ -57,7 +65,7 @@ func (s *testHelperSuite) TestMentionedColumns(c *C) {
|
||||
}{
|
||||
{Value{1}, 0},
|
||||
{&BinaryOperation{L: v, R: v}, 0},
|
||||
{&Ident{model.NewCIStr("id")}, 1},
|
||||
{&Ident{CIStr: model.NewCIStr("id")}, 1},
|
||||
{&Call{F: "count", Args: []Expression{v}}, 0},
|
||||
{&IsNull{Expr: v}, 0},
|
||||
{&PExpr{Expr: v}, 0},
|
||||
@ -106,7 +114,7 @@ func (s *testHelperSuite) TestBase(c *C) {
|
||||
{int64(1), int64(1)},
|
||||
{&UnaryOperation{Op: opcode.Plus, V: Value{1}}, 1},
|
||||
{&UnaryOperation{Op: opcode.Not, V: Value{1}}, nil},
|
||||
{&UnaryOperation{Op: opcode.Plus, V: &Ident{model.NewCIStr("id")}}, nil},
|
||||
{&UnaryOperation{Op: opcode.Plus, V: &Ident{CIStr: model.NewCIStr("id")}}, nil},
|
||||
{nil, nil},
|
||||
}
|
||||
|
||||
@ -228,7 +236,7 @@ func (s *testHelperSuite) TestGetTimeValue(c *C) {
|
||||
{Value{"2012-13-12 00:00:00"}},
|
||||
{Value{0}},
|
||||
{Value{int64(1)}},
|
||||
{&Ident{model.NewCIStr("xxx")}},
|
||||
{&Ident{CIStr: model.NewCIStr("xxx")}},
|
||||
{NewUnaryOperation(opcode.Minus, Value{int64(1)})},
|
||||
}
|
||||
|
||||
|
||||
@ -29,10 +29,23 @@ var (
|
||||
_ Expression = (*Ident)(nil)
|
||||
)
|
||||
|
||||
const (
|
||||
// IdentReferSelectList means the identifier reference is in select list.
|
||||
IdentReferSelectList = 1
|
||||
// IdentReferFromTable means the identifier reference is in FROM table.
|
||||
IdentReferFromTable = 2
|
||||
)
|
||||
|
||||
// Ident is the identifier expression.
|
||||
type Ident struct {
|
||||
// model.CIStr contains origin identifier name and its lowercase name.
|
||||
model.CIStr
|
||||
|
||||
// ReferScope means where the identifer reference is, select list or from.
|
||||
ReferScope int
|
||||
|
||||
// ReferIndex is the index to get the identifer data.
|
||||
ReferIndex int
|
||||
}
|
||||
|
||||
// Clone implements the Expression Clone interface.
|
||||
@ -63,6 +76,14 @@ func (i *Ident) Eval(ctx context.Context, args map[interface{}]interface{}) (v i
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// TODO: we will unify ExprEvalIdentReferFunc and ExprEvalIdentFunc later,
|
||||
// now just put them here for refactor step by step.
|
||||
if f, ok := args[ExprEvalIdentReferFunc]; ok {
|
||||
if got, ok := f.(func(string, int, int) (interface{}, error)); ok {
|
||||
return got(i.L, i.ReferScope, i.ReferIndex)
|
||||
}
|
||||
}
|
||||
|
||||
if f, ok := args[ExprEvalIdentFunc]; ok {
|
||||
if got, ok := f.(func(string) (interface{}, error)); ok {
|
||||
return got(i.L)
|
||||
|
||||
@ -26,7 +26,7 @@ type testIdentSuite struct {
|
||||
|
||||
func (s *testIdentSuite) TestIdent(c *C) {
|
||||
e := Ident{
|
||||
model.NewCIStr("id"),
|
||||
CIStr: model.NewCIStr("id"),
|
||||
}
|
||||
|
||||
c.Assert(e.IsStatic(), IsFalse)
|
||||
@ -55,4 +55,14 @@ func (s *testIdentSuite) TestIdent(c *C) {
|
||||
v, err = e.Eval(nil, m)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(v, Equals, 1)
|
||||
|
||||
delete(m, ExprEvalIdentFunc)
|
||||
e.ReferScope = IdentReferSelectList
|
||||
e.ReferIndex = 1
|
||||
m[ExprEvalIdentReferFunc] = func(string, int, int) (interface{}, error) {
|
||||
return 2, nil
|
||||
}
|
||||
v, err = e.Eval(nil, m)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(v, Equals, 2)
|
||||
}
|
||||
|
||||
@ -23,7 +23,7 @@ import (
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/context"
|
||||
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/parser/opcode"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
@ -19,7 +19,7 @@ import (
|
||||
|
||||
. "github.com/pingcap/check"
|
||||
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/parser/opcode"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
@ -106,6 +106,9 @@ type Visitor interface {
|
||||
|
||||
// VisitFunctionTrim visits FunctionTrim expression.
|
||||
VisitFunctionTrim(v *FunctionTrim) (Expression, error)
|
||||
|
||||
// VisitDateAdd visits DateAdd expression.
|
||||
VisitDateAdd(da *DateAdd) (Expression, error)
|
||||
}
|
||||
|
||||
// BaseVisitor is the base implementation of Visitor.
|
||||
@ -456,7 +459,7 @@ func (bv *BaseVisitor) VisitWhenClause(w *WhenClause) (Expression, error) {
|
||||
return w, nil
|
||||
}
|
||||
|
||||
// VisitExtract implements Visitor
|
||||
// VisitExtract implements Visitor interface.
|
||||
func (bv *BaseVisitor) VisitExtract(v *Extract) (Expression, error) {
|
||||
var err error
|
||||
v.Date, err = v.Date.Accept(bv.V)
|
||||
@ -478,3 +481,19 @@ func (bv *BaseVisitor) VisitFunctionTrim(ss *FunctionTrim) (Expression, error) {
|
||||
}
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
// VisitDateAdd implements Visitor interface.
|
||||
func (bv *BaseVisitor) VisitDateAdd(da *DateAdd) (Expression, error) {
|
||||
var err error
|
||||
da.Date, err = da.Date.Accept(bv.V)
|
||||
if err != nil {
|
||||
return da, errors.Trace(err)
|
||||
}
|
||||
|
||||
da.Interval, err = da.Interval.Accept(bv.V)
|
||||
if err != nil {
|
||||
return da, errors.Trace(err)
|
||||
}
|
||||
|
||||
return da, nil
|
||||
}
|
||||
|
||||
@ -17,7 +17,7 @@ import (
|
||||
. "github.com/pingcap/check"
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/model"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/parser/opcode"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
@ -19,7 +19,7 @@ import (
|
||||
. "github.com/pingcap/check"
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/field"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ import (
|
||||
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/column"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@ -19,7 +19,7 @@ import (
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/field"
|
||||
"github.com/pingcap/tidb/model"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ import (
|
||||
. "github.com/pingcap/check"
|
||||
"github.com/pingcap/tidb/infoschema"
|
||||
"github.com/pingcap/tidb/model"
|
||||
"github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/store/localstore"
|
||||
"github.com/pingcap/tidb/store/localstore/goleveldb"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
@ -51,7 +51,7 @@ func (*testSuite) TestT(c *C) {
|
||||
ID: 3,
|
||||
Name: colName,
|
||||
Offset: 0,
|
||||
FieldType: *types.NewFieldType(mysqldef.TypeLonglong),
|
||||
FieldType: *types.NewFieldType(mysql.TypeLonglong),
|
||||
}
|
||||
|
||||
idxInfo := &model.IndexInfo{
|
||||
|
||||
40
kv/compactor.go
Normal file
40
kv/compactor.go
Normal file
@ -0,0 +1,40 @@
|
||||
// Copyright 2015 PingCAP, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package kv
|
||||
|
||||
import "time"
|
||||
|
||||
// CompactPolicy defines gc policy of MVCC storage.
|
||||
type CompactPolicy struct {
|
||||
// SafePoint specifies
|
||||
SafePoint int
|
||||
// TriggerInterval specifies how often should the compactor
|
||||
// scans outdated data.
|
||||
TriggerInterval time.Duration
|
||||
// BatchDeleteCnt specifies the batch size for
|
||||
// deleting outdated data transaction.
|
||||
BatchDeleteCnt int
|
||||
}
|
||||
|
||||
// Compactor compacts MVCC storage.
|
||||
type Compactor interface {
|
||||
// OnGet is the hook point on Txn.Get.
|
||||
OnGet(k Key)
|
||||
// OnSet is the hook point on Txn.Set.
|
||||
OnSet(k Key)
|
||||
// OnDelete is the hook point on Txn.Delete.
|
||||
OnDelete(k Key)
|
||||
// Compact is the function removes the given key.
|
||||
Compact(ctx interface{}, k Key) error
|
||||
}
|
||||
@ -87,7 +87,7 @@ func (c *indexIter) Next() (k []interface{}, h int64, err error) {
|
||||
k = vv
|
||||
}
|
||||
// update new iter to next
|
||||
newIt, err := c.it.Next(hasPrefix([]byte(c.prefix)))
|
||||
newIt, err := c.it.Next()
|
||||
if err != nil {
|
||||
return nil, 0, errors.Trace(err)
|
||||
}
|
||||
@ -189,16 +189,10 @@ func (c *kvIndex) Delete(txn Transaction, indexedValues []interface{}, h int64)
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
func hasPrefix(prefix []byte) FnKeyCmp {
|
||||
return func(k Key) bool {
|
||||
return bytes.HasPrefix([]byte(k), prefix)
|
||||
}
|
||||
}
|
||||
|
||||
// Drop removes the KV index from store.
|
||||
func (c *kvIndex) Drop(txn Transaction) error {
|
||||
prefix := []byte(c.prefix)
|
||||
it, err := txn.Seek(Key(prefix), hasPrefix(prefix))
|
||||
it, err := txn.Seek(Key(prefix))
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
@ -213,7 +207,7 @@ func (c *kvIndex) Drop(txn Transaction) error {
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
it, err = it.Next(hasPrefix(prefix))
|
||||
it, err = it.Next()
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
@ -227,7 +221,7 @@ func (c *kvIndex) Seek(txn Transaction, indexedValues []interface{}) (iter Index
|
||||
if err != nil {
|
||||
return nil, false, errors.Trace(err)
|
||||
}
|
||||
it, err := txn.Seek(keyBuf, hasPrefix([]byte(c.prefix)))
|
||||
it, err := txn.Seek(keyBuf)
|
||||
if err != nil {
|
||||
return nil, false, errors.Trace(err)
|
||||
}
|
||||
@ -242,7 +236,7 @@ func (c *kvIndex) Seek(txn Transaction, indexedValues []interface{}) (iter Index
|
||||
// SeekFirst returns an iterator which points to the first entry of the KV index.
|
||||
func (c *kvIndex) SeekFirst(txn Transaction) (iter IndexIterator, err error) {
|
||||
prefix := []byte(c.prefix)
|
||||
it, err := txn.Seek(prefix, hasPrefix(prefix))
|
||||
it, err := txn.Seek(prefix)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
@ -72,7 +72,7 @@ func DecodeValue(data []byte) ([]interface{}, error) {
|
||||
func NextUntil(it Iterator, fn FnKeyCmp) (Iterator, error) {
|
||||
var err error
|
||||
for it.Valid() && !fn([]byte(it.Key())) {
|
||||
it, err = it.Next(nil)
|
||||
it, err = it.Next()
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
4
kv/kv.go
4
kv/kv.go
@ -104,7 +104,7 @@ type Transaction interface {
|
||||
// Set sets the value for key k as v into KV store.
|
||||
Set(k Key, v []byte) error
|
||||
// Seek searches for the entry with key k in KV store.
|
||||
Seek(k Key, fnKeyCmp func(key Key) bool) (Iterator, error)
|
||||
Seek(k Key) (Iterator, error)
|
||||
// Inc increases the value for key k in KV store by step.
|
||||
Inc(k Key, step int64) (int64, error)
|
||||
// GetInt64 get int64 which created by Inc method.
|
||||
@ -171,7 +171,7 @@ type FnKeyCmp func(key Key) bool
|
||||
|
||||
// Iterator is the interface for a interator on KV store.
|
||||
type Iterator interface {
|
||||
Next(FnKeyCmp) (Iterator, error)
|
||||
Next() (Iterator, error)
|
||||
Value() []byte
|
||||
Key() string
|
||||
Valid() bool
|
||||
|
||||
@ -51,7 +51,7 @@ func (iter *UnionIter) dirtyNext() {
|
||||
|
||||
// Go next and update valid status.
|
||||
func (iter *UnionIter) snapshotNext() {
|
||||
iter.snapshotIt, _ = iter.snapshotIt.Next(nil)
|
||||
iter.snapshotIt, _ = iter.snapshotIt.Next()
|
||||
iter.snapshotValid = iter.snapshotIt.Valid()
|
||||
}
|
||||
|
||||
@ -116,7 +116,7 @@ func (iter *UnionIter) updateCur() {
|
||||
}
|
||||
|
||||
// Next implements the Iterator Next interface.
|
||||
func (iter *UnionIter) Next(f FnKeyCmp) (Iterator, error) {
|
||||
func (iter *UnionIter) Next() (Iterator, error) {
|
||||
if !iter.curIsDirty {
|
||||
iter.snapshotNext()
|
||||
} else {
|
||||
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
import . "github.com/pingcap/check"
|
||||
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
// CharsetIDs maps charset name to its default collation ID.
|
||||
var CharsetIDs = map[string]uint8{
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
// Version informations.
|
||||
const (
|
||||
@ -57,7 +57,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
// Decimal implements an arbitrary precision fixed-point decimal.
|
||||
//
|
||||
@ -57,7 +57,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
import (
|
||||
. "github.com/pingcap/check"
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
// MySQL error code.
|
||||
// This value is numeric. It is not portable to other database systems.
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
// MySQLErrName maps error code to MySQL error messages.
|
||||
var MySQLErrName = map[uint16]string{
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
import (
|
||||
. "github.com/pingcap/check"
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"math"
|
||||
@ -80,3 +80,13 @@ func parseFrac(s string, fsp int) (int, error) {
|
||||
// 0.0312 round 2 -> 3 -> 30000
|
||||
return int(round * math.Pow10(MaxFsp-fsp)), nil
|
||||
}
|
||||
|
||||
// alignFrac is used to generate alignment frac, like `100` -> `100000`
|
||||
func alignFrac(s string, fsp int) string {
|
||||
sl := len(s)
|
||||
if sl < fsp {
|
||||
return s + strings.Repeat("0", fsp-sl)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
import (
|
||||
. "github.com/pingcap/check"
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
const (
|
||||
// DefaultMySQLState is default state of the mySQL
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@ -979,3 +979,364 @@ func checkTimestamp(t Time) bool {
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// ExtractTimeNum extracts time value number from time unit and format.
|
||||
func ExtractTimeNum(unit string, t Time) (int64, error) {
|
||||
switch strings.ToUpper(unit) {
|
||||
case "MICROSECOND":
|
||||
return int64(t.Nanosecond() / 1000), nil
|
||||
case "SECOND":
|
||||
return int64(t.Second()), nil
|
||||
case "MINUTE":
|
||||
return int64(t.Minute()), nil
|
||||
case "HOUR":
|
||||
return int64(t.Hour()), nil
|
||||
case "DAY":
|
||||
return int64(t.Day()), nil
|
||||
case "WEEK":
|
||||
_, week := t.ISOWeek()
|
||||
return int64(week), nil
|
||||
case "MONTH":
|
||||
return int64(t.Month()), nil
|
||||
case "QUARTER":
|
||||
m := int64(t.Month())
|
||||
// 1 - 3 -> 1
|
||||
// 4 - 6 -> 2
|
||||
// 7 - 9 -> 3
|
||||
// 10 - 12 -> 4
|
||||
return (m + 2) / 3, nil
|
||||
case "YEAR":
|
||||
return int64(t.Year()), nil
|
||||
case "SECOND_MICROSECOND":
|
||||
return int64(t.Second())*1000000 + int64(t.Nanosecond())/1000, nil
|
||||
case "MINUTE_MICROSECOND":
|
||||
_, m, s := t.Clock()
|
||||
return int64(m)*100000000 + int64(s)*1000000 + int64(t.Nanosecond())/1000, nil
|
||||
case "MINUTE_SECOND":
|
||||
_, m, s := t.Clock()
|
||||
return int64(m*100 + s), nil
|
||||
case "HOUR_MICROSECOND":
|
||||
h, m, s := t.Clock()
|
||||
return int64(h)*10000000000 + int64(m)*100000000 + int64(s)*1000000 + int64(t.Nanosecond())/1000, nil
|
||||
case "HOUR_SECOND":
|
||||
h, m, s := t.Clock()
|
||||
return int64(h)*10000 + int64(m)*100 + int64(s), nil
|
||||
case "HOUR_MINUTE":
|
||||
h, m, _ := t.Clock()
|
||||
return int64(h)*100 + int64(m), nil
|
||||
case "DAY_MICROSECOND":
|
||||
h, m, s := t.Clock()
|
||||
d := t.Day()
|
||||
return int64(d*1000000+h*10000+m*100+s)*1000000 + int64(t.Nanosecond())/1000, nil
|
||||
case "DAY_SECOND":
|
||||
h, m, s := t.Clock()
|
||||
d := t.Day()
|
||||
return int64(d)*1000000 + int64(h)*10000 + int64(m)*100 + int64(s), nil
|
||||
case "DAY_MINUTE":
|
||||
h, m, _ := t.Clock()
|
||||
d := t.Day()
|
||||
return int64(d)*10000 + int64(h)*100 + int64(m), nil
|
||||
case "DAY_HOUR":
|
||||
h, _, _ := t.Clock()
|
||||
d := t.Day()
|
||||
return int64(d)*100 + int64(h), nil
|
||||
case "YEAR_MONTH":
|
||||
y, m, _ := t.Date()
|
||||
return int64(y)*100 + int64(m), nil
|
||||
default:
|
||||
return 0, errors.Errorf("invalid unit %s", unit)
|
||||
}
|
||||
}
|
||||
|
||||
func extractSingleTimeValue(unit string, format string) (int64, int64, int64, time.Duration, error) {
|
||||
iv, err := strconv.ParseInt(format, 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
v := time.Duration(iv)
|
||||
switch strings.ToUpper(unit) {
|
||||
case "MICROSECOND":
|
||||
return 0, 0, 0, v * time.Microsecond, nil
|
||||
case "SECOND":
|
||||
return 0, 0, 0, v * time.Second, nil
|
||||
case "MINUTE":
|
||||
return 0, 0, 0, v * time.Minute, nil
|
||||
case "HOUR":
|
||||
return 0, 0, 0, v * time.Hour, nil
|
||||
case "DAY":
|
||||
return 0, 0, iv, 0, nil
|
||||
case "WEEK":
|
||||
return 0, 0, 7 * iv, 0, nil
|
||||
case "MONTH":
|
||||
return 0, iv, 0, 0, nil
|
||||
case "QUARTER":
|
||||
return 0, 3 * iv, 0, 0, nil
|
||||
case "YEAR":
|
||||
return iv, 0, 0, 0, nil
|
||||
}
|
||||
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid singel timeunit - %s", unit)
|
||||
}
|
||||
|
||||
// Format is `SS.FFFFFF`.
|
||||
func extractSecondMicrosecond(format string) (int64, int64, int64, time.Duration, error) {
|
||||
fields := strings.Split(format, ".")
|
||||
if len(fields) != 2 {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
seconds, err := strconv.ParseInt(fields[0], 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
microseconds, err := strconv.ParseInt(alignFrac(fields[1], MaxFsp), 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
return 0, 0, 0, time.Duration(seconds)*time.Second + time.Duration(microseconds)*time.Microsecond, nil
|
||||
}
|
||||
|
||||
// Format is `MM:SS.FFFFFF`.
|
||||
func extractMinuteMicrosecond(format string) (int64, int64, int64, time.Duration, error) {
|
||||
fields := strings.Split(format, ":")
|
||||
if len(fields) != 2 {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
minutes, err := strconv.ParseInt(fields[0], 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
_, _, _, value, err := extractSecondMicrosecond(fields[1])
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Trace(err)
|
||||
}
|
||||
|
||||
return 0, 0, 0, time.Duration(minutes)*time.Minute + value, nil
|
||||
}
|
||||
|
||||
// Format is `MM:SS`.
|
||||
func extractMinuteSecond(format string) (int64, int64, int64, time.Duration, error) {
|
||||
fields := strings.Split(format, ":")
|
||||
if len(fields) != 2 {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
minutes, err := strconv.ParseInt(fields[0], 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
seconds, err := strconv.ParseInt(fields[1], 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
return 0, 0, 0, time.Duration(minutes)*time.Minute + time.Duration(seconds)*time.Second, nil
|
||||
}
|
||||
|
||||
// Format is `HH:MM:SS.FFFFFF`.
|
||||
func extractHourMicrosecond(format string) (int64, int64, int64, time.Duration, error) {
|
||||
fields := strings.Split(format, ":")
|
||||
if len(fields) != 3 {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
hours, err := strconv.ParseInt(fields[0], 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
minutes, err := strconv.ParseInt(fields[1], 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
_, _, _, value, err := extractSecondMicrosecond(fields[2])
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Trace(err)
|
||||
}
|
||||
|
||||
return 0, 0, 0, time.Duration(hours)*time.Hour + time.Duration(minutes)*time.Minute + value, nil
|
||||
}
|
||||
|
||||
// Format is `HH:MM:SS`.
|
||||
func extractHourSecond(format string) (int64, int64, int64, time.Duration, error) {
|
||||
fields := strings.Split(format, ":")
|
||||
if len(fields) != 3 {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
hours, err := strconv.ParseInt(fields[0], 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
minutes, err := strconv.ParseInt(fields[1], 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
seconds, err := strconv.ParseInt(fields[2], 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
return 0, 0, 0, time.Duration(hours)*time.Hour + time.Duration(minutes)*time.Minute + time.Duration(seconds)*time.Second, nil
|
||||
}
|
||||
|
||||
// Format is `HH:MM`.
|
||||
func extractHourMinute(format string) (int64, int64, int64, time.Duration, error) {
|
||||
fields := strings.Split(format, ":")
|
||||
if len(fields) != 2 {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
hours, err := strconv.ParseInt(fields[0], 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
minutes, err := strconv.ParseInt(fields[1], 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
return 0, 0, 0, time.Duration(hours)*time.Hour + time.Duration(minutes)*time.Minute, nil
|
||||
}
|
||||
|
||||
// Format is `DD HH:MM:SS.FFFFFF`.
|
||||
func extractDayMicrosecond(format string) (int64, int64, int64, time.Duration, error) {
|
||||
fields := strings.Split(format, " ")
|
||||
if len(fields) != 2 {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
days, err := strconv.ParseInt(fields[0], 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
_, _, _, value, err := extractHourMicrosecond(fields[1])
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
return 0, 0, days, value, nil
|
||||
}
|
||||
|
||||
// Format is `DD HH:MM:SS`.
|
||||
func extractDaySecond(format string) (int64, int64, int64, time.Duration, error) {
|
||||
fields := strings.Split(format, " ")
|
||||
if len(fields) != 2 {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
days, err := strconv.ParseInt(fields[0], 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
_, _, _, value, err := extractHourSecond(fields[1])
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
return 0, 0, days, value, nil
|
||||
}
|
||||
|
||||
// Format is `DD HH:MM`.
|
||||
func extractDayMinute(format string) (int64, int64, int64, time.Duration, error) {
|
||||
fields := strings.Split(format, " ")
|
||||
if len(fields) != 2 {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
days, err := strconv.ParseInt(fields[0], 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
_, _, _, value, err := extractHourMinute(fields[1])
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
return 0, 0, days, value, nil
|
||||
}
|
||||
|
||||
// Format is `DD HH`.
|
||||
func extractDayHour(format string) (int64, int64, int64, time.Duration, error) {
|
||||
fields := strings.Split(format, " ")
|
||||
if len(fields) != 2 {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
days, err := strconv.ParseInt(fields[0], 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
hours, err := strconv.ParseInt(fields[1], 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
return 0, 0, days, time.Duration(hours) * time.Hour, nil
|
||||
}
|
||||
|
||||
// Format is `YYYY-MM`.
|
||||
func extractYearMonth(format string) (int64, int64, int64, time.Duration, error) {
|
||||
fields := strings.Split(format, "-")
|
||||
if len(fields) != 2 {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
years, err := strconv.ParseInt(fields[0], 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
months, err := strconv.ParseInt(fields[1], 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format)
|
||||
}
|
||||
|
||||
return years, months, 0, 0, nil
|
||||
}
|
||||
|
||||
// ExtractTimeValue extracts time value from time unit and format.
|
||||
func ExtractTimeValue(unit string, format string) (int64, int64, int64, time.Duration, error) {
|
||||
switch strings.ToUpper(unit) {
|
||||
case "MICROSECOND", "SECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "QUARTER", "YEAR":
|
||||
return extractSingleTimeValue(unit, format)
|
||||
case "SECOND_MICROSECOND":
|
||||
return extractSecondMicrosecond(format)
|
||||
case "MINUTE_MICROSECOND":
|
||||
return extractMinuteMicrosecond(format)
|
||||
case "MINUTE_SECOND":
|
||||
return extractMinuteSecond(format)
|
||||
case "HOUR_MICROSECOND":
|
||||
return extractHourMicrosecond(format)
|
||||
case "HOUR_SECOND":
|
||||
return extractHourSecond(format)
|
||||
case "HOUR_MINUTE":
|
||||
return extractHourMinute(format)
|
||||
case "DAY_MICROSECOND":
|
||||
return extractDayMicrosecond(format)
|
||||
case "DAY_SECOND":
|
||||
return extractDaySecond(format)
|
||||
case "DAY_MINUTE":
|
||||
return extractDayMinute(format)
|
||||
case "DAY_HOUR":
|
||||
return extractDayHour(format)
|
||||
case "YEAR_MONTH":
|
||||
return extractYearMonth(format)
|
||||
default:
|
||||
return 0, 0, 0, 0, errors.Errorf("invalid singel timeunit - %s", unit)
|
||||
}
|
||||
}
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
// MySQL type informations.
|
||||
const (
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
import . "github.com/pingcap/check"
|
||||
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
// GetDefaultFieldLength is used for Interger Types, Flen is the display length.
|
||||
// Call this when no Flen assigned in ddl.
|
||||
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqldef
|
||||
package mysql
|
||||
|
||||
import "testing"
|
||||
|
||||
@ -21,7 +21,7 @@ import (
|
||||
"github.com/pingcap/tidb/column"
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/model"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/table"
|
||||
"github.com/pingcap/tidb/util/charset"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
|
||||
@ -18,7 +18,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/pingcap/tidb/expression"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
)
|
||||
|
||||
// FloatOpt is used for parsing floating-point type option from SQL.
|
||||
|
||||
@ -29,7 +29,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/ast"
|
||||
"github.com/pingcap/tidb/parser/coldef"
|
||||
"github.com/pingcap/tidb/ddl"
|
||||
@ -113,6 +113,7 @@ import (
|
||||
currentUser "CURRENT_USER"
|
||||
database "DATABASE"
|
||||
databases "DATABASES"
|
||||
dateAdd "DATE_ADD"
|
||||
day "DAY"
|
||||
dayofmonth "DAYOFMONTH"
|
||||
dayofweek "DAYOFWEEK"
|
||||
@ -164,6 +165,7 @@ import (
|
||||
index "INDEX"
|
||||
inner "INNER"
|
||||
insert "INSERT"
|
||||
interval "INTERVAL"
|
||||
into "INTO"
|
||||
is "IS"
|
||||
join "JOIN"
|
||||
@ -707,7 +709,6 @@ ColumnPosition:
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
AlterSpecificationList:
|
||||
AlterSpecification
|
||||
{
|
||||
@ -1702,7 +1703,7 @@ UnReservedKeyword:
|
||||
| "NATIONAL" | "ROW" | "QUARTER" | "ESCAPE"
|
||||
|
||||
NotKeywordToken:
|
||||
"ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT"
|
||||
"ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DATE_ADD" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT"
|
||||
| "HOUR" | "IFNULL" | "LENGTH" | "LOCATE" | "MAX" | "MICROSECOND" | "MIN" | "MINUTE" | "NULLIF" | "MONTH" | "NOW" | "RAND" | "SECOND" | "SQL_CALC_FOUND_ROWS"
|
||||
| "SUBSTRING" %prec lowerThanLeftParen | "SUBSTRING_INDEX" | "SUM" | "TRIM" | "WEEKDAY" | "WEEKOFYEAR" | "YEARWEEK"
|
||||
|
||||
@ -2239,6 +2240,14 @@ FunctionCallNonKeyword:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
| "DATE_ADD" '(' Expression ',' "INTERVAL" Expression TimeUnit ')'
|
||||
{
|
||||
$$ = &expression.DateAdd{
|
||||
Unit: $7.(string),
|
||||
Date: $3.(expression.Expression),
|
||||
Interval: $6.(expression.Expression),
|
||||
}
|
||||
}
|
||||
| "EXTRACT" '(' TimeUnit "FROM" Expression ')'
|
||||
{
|
||||
$$ = &expression.Extract{
|
||||
|
||||
@ -403,6 +403,28 @@ func (s *testParserSuite) TestBuiltin(c *C) {
|
||||
{`SELECT TRIM(LEADING 'x' FROM 'xxxbarxxx');`, true},
|
||||
{`SELECT TRIM(BOTH 'x' FROM 'xxxbarxxx');`, true},
|
||||
{`SELECT TRIM(TRAILING 'xyz' FROM 'barxxyz');`, true},
|
||||
|
||||
// For date_add
|
||||
{`select date_add("2011-11-11 10:10:10.123456", interval 10 microsecond)`, true},
|
||||
{`select date_add("2011-11-11 10:10:10.123456", interval 10 second)`, true},
|
||||
{`select date_add("2011-11-11 10:10:10.123456", interval 10 minute)`, true},
|
||||
{`select date_add("2011-11-11 10:10:10.123456", interval 10 hour)`, true},
|
||||
{`select date_add("2011-11-11 10:10:10.123456", interval 10 day)`, true},
|
||||
{`select date_add("2011-11-11 10:10:10.123456", interval 1 week)`, true},
|
||||
{`select date_add("2011-11-11 10:10:10.123456", interval 1 month)`, true},
|
||||
{`select date_add("2011-11-11 10:10:10.123456", interval 1 quarter)`, true},
|
||||
{`select date_add("2011-11-11 10:10:10.123456", interval 1 year)`, true},
|
||||
{`select date_add("2011-11-11 10:10:10.123456", interval "10.10" second_microsecond)`, true},
|
||||
{`select date_add("2011-11-11 10:10:10.123456", interval "10:10.10" minute_microsecond)`, true},
|
||||
{`select date_add("2011-11-11 10:10:10.123456", interval "10:10" minute_second)`, true},
|
||||
{`select date_add("2011-11-11 10:10:10.123456", interval "10:10:10.10" hour_microsecond)`, true},
|
||||
{`select date_add("2011-11-11 10:10:10.123456", interval "10:10:10" hour_second)`, true},
|
||||
{`select date_add("2011-11-11 10:10:10.123456", interval "10:10" hour_minute)`, true},
|
||||
{`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10:10.10" day_microsecond)`, true},
|
||||
{`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10:10" day_second)`, true},
|
||||
{`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true},
|
||||
{`select date_add("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true},
|
||||
{`select date_add("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true},
|
||||
}
|
||||
s.RunTest(c, table)
|
||||
}
|
||||
|
||||
@ -27,7 +27,7 @@ import (
|
||||
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/util/stringutil"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
)
|
||||
|
||||
type lexer struct {
|
||||
@ -284,6 +284,7 @@ current_date {c}{u}{r}{r}{e}{n}{t}_{d}{a}{t}{e}
|
||||
current_user {c}{u}{r}{r}{e}{n}{t}_{u}{s}{e}{r}
|
||||
database {d}{a}{t}{a}{b}{a}{s}{e}
|
||||
databases {d}{a}{t}{a}{b}{a}{s}{e}{s}
|
||||
date_add {d}{a}{t}{e}_{a}{d}{d}
|
||||
day {d}{a}{y}
|
||||
dayofweek {d}{a}{y}{o}{f}{w}{e}{e}{k}
|
||||
dayofmonth {d}{a}{y}{o}{f}{m}{o}{n}{t}{h}
|
||||
@ -331,6 +332,7 @@ in {i}{n}
|
||||
index {i}{n}{d}{e}{x}
|
||||
inner {i}{n}{n}{e}{r}
|
||||
insert {i}{n}{s}{e}{r}{t}
|
||||
interval {i}{n}{t}{e}{r}{v}{a}{l}
|
||||
into {i}{n}{t}{o}
|
||||
is {i}{s}
|
||||
join {j}{o}{i}{n}
|
||||
@ -630,6 +632,8 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h}
|
||||
{database} lval.item = string(l.val)
|
||||
return database
|
||||
{databases} return databases
|
||||
{date_add} lval.item = string(l.val)
|
||||
return dateAdd
|
||||
{day} lval.item = string(l.val)
|
||||
return day
|
||||
{dayofweek} lval.item = string(l.val)
|
||||
@ -711,6 +715,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h}
|
||||
{index} return index
|
||||
{inner} return inner
|
||||
{insert} return insert
|
||||
{interval} return interval
|
||||
{into} return into
|
||||
{in} return in
|
||||
{is} return is
|
||||
|
||||
@ -64,10 +64,9 @@ func (r *SelectFieldsDefaultPlan) Next(ctx context.Context) (row *plan.Row, err
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
|
||||
v, err0 := GetIdentValue(name, r.Src.GetFields(), srcRow.Data, field.DefaultFieldFlag)
|
||||
if err0 == nil {
|
||||
return v, nil
|
||||
r.evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
|
||||
if scope == expression.IdentReferFromTable {
|
||||
return srcRow.FromData[index], nil
|
||||
}
|
||||
|
||||
return getIdentValueFromOuterQuery(ctx, name)
|
||||
|
||||
@ -18,7 +18,7 @@ import (
|
||||
"github.com/pingcap/tidb/column"
|
||||
"github.com/pingcap/tidb/field"
|
||||
"github.com/pingcap/tidb/model"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/plan/plans"
|
||||
"github.com/pingcap/tidb/rset/rsets"
|
||||
"github.com/pingcap/tidb/util/charset"
|
||||
|
||||
@ -68,7 +68,7 @@ func (r *TableNilPlan) Next(ctx context.Context) (row *plan.Row, err error) {
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
r.iter, err = txn.Seek([]byte(r.T.FirstKey()), nil)
|
||||
r.iter, err = txn.Seek([]byte(r.T.FirstKey()))
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
@ -274,7 +274,7 @@ func (r *TableDefaultPlan) Next(ctx context.Context) (row *plan.Row, err error)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
r.iter, err = txn.Seek([]byte(r.T.FirstKey()), nil)
|
||||
r.iter, err = txn.Seek([]byte(r.T.FirstKey()))
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
@ -24,7 +24,7 @@ import (
|
||||
"github.com/pingcap/tidb/field"
|
||||
"github.com/pingcap/tidb/kv"
|
||||
"github.com/pingcap/tidb/model"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/parser/opcode"
|
||||
"github.com/pingcap/tidb/plan/plans"
|
||||
"github.com/pingcap/tidb/rset/rsets"
|
||||
|
||||
@ -89,15 +89,11 @@ type groupRow struct {
|
||||
func (r *GroupByDefaultPlan) evalGroupKey(ctx context.Context, k []interface{}, outRow []interface{}, in []interface{}) error {
|
||||
// group by items can not contain aggregate field, so we can eval them safely.
|
||||
m := map[interface{}]interface{}{}
|
||||
m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
|
||||
v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag)
|
||||
if err == nil {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
v, err = r.getFieldValueByName(name, outRow)
|
||||
if err == nil {
|
||||
return v, nil
|
||||
m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
|
||||
if scope == expression.IdentReferFromTable {
|
||||
return in[index], nil
|
||||
} else if scope == expression.IdentReferSelectList {
|
||||
return outRow[index], nil
|
||||
}
|
||||
|
||||
// try to find in outer query
|
||||
@ -132,10 +128,11 @@ func (r *GroupByDefaultPlan) getFieldValueByName(name string, out []interface{})
|
||||
|
||||
func (r *GroupByDefaultPlan) evalNoneAggFields(ctx context.Context, out []interface{},
|
||||
m map[interface{}]interface{}, in []interface{}) error {
|
||||
m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
|
||||
v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag)
|
||||
if err == nil {
|
||||
return v, nil
|
||||
m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
|
||||
if scope == expression.IdentReferFromTable {
|
||||
return in[index], nil
|
||||
} else if scope == expression.IdentReferSelectList {
|
||||
return out[index], nil
|
||||
}
|
||||
|
||||
// try to find in outer query
|
||||
@ -157,42 +154,20 @@ func (r *GroupByDefaultPlan) evalNoneAggFields(ctx context.Context, out []interf
|
||||
|
||||
func (r *GroupByDefaultPlan) evalAggFields(ctx context.Context, out []interface{},
|
||||
m map[interface{}]interface{}, in []interface{}) error {
|
||||
// Eval aggregate field results in ctx
|
||||
for i := range r.AggFields {
|
||||
if i < r.HiddenFieldOffset {
|
||||
m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
|
||||
v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag)
|
||||
if err == nil {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// try to find in outer query
|
||||
return getIdentValueFromOuterQuery(ctx, name)
|
||||
}
|
||||
} else {
|
||||
// having may contain aggregate function and we will add it to hidden field,
|
||||
// and this field can retrieve the data in select list, e.g.
|
||||
// select c1 as a from t having count(a) = 1
|
||||
// because all the select list data is generated before, so we can get it
|
||||
// when handling hidden field.
|
||||
m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
|
||||
v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag)
|
||||
if err == nil {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// if we can not find in table, we will try to find in un-hidden select list
|
||||
// only hidden field can use this
|
||||
v, err = r.getFieldValueByName(name, out)
|
||||
if err == nil {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// try to find in outer query
|
||||
return getIdentValueFromOuterQuery(ctx, name)
|
||||
}
|
||||
m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
|
||||
if scope == expression.IdentReferFromTable {
|
||||
return in[index], nil
|
||||
} else if scope == expression.IdentReferSelectList {
|
||||
return out[index], nil
|
||||
}
|
||||
|
||||
// try to find in outer query
|
||||
return getIdentValueFromOuterQuery(ctx, name)
|
||||
}
|
||||
|
||||
// Eval aggregate field results in ctx
|
||||
for i := range r.AggFields {
|
||||
// we must evaluate aggregate function only, e.g, select col1 + count(*) in (count(*)),
|
||||
// we cannot evaluate it directly here, because col1 + count(*) returns nil before AggDone phase,
|
||||
// so we don't evaluate count(*) in In expression, and will get an invalid data in AggDone phase for it.
|
||||
|
||||
@ -42,12 +42,16 @@ func (t *testGroupBySuite) TestGroupBy(c *C) {
|
||||
Fields: []*field.Field{
|
||||
{
|
||||
Expr: &expression.Ident{
|
||||
CIStr: model.NewCIStr("id"),
|
||||
CIStr: model.NewCIStr("id"),
|
||||
ReferScope: expression.IdentReferFromTable,
|
||||
ReferIndex: 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
Expr: &expression.Ident{
|
||||
CIStr: model.NewCIStr("name"),
|
||||
CIStr: model.NewCIStr("name"),
|
||||
ReferScope: expression.IdentReferFromTable,
|
||||
ReferIndex: 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -55,7 +59,9 @@ func (t *testGroupBySuite) TestGroupBy(c *C) {
|
||||
F: "sum",
|
||||
Args: []expression.Expression{
|
||||
&expression.Ident{
|
||||
CIStr: model.NewCIStr("id"),
|
||||
CIStr: model.NewCIStr("id"),
|
||||
ReferScope: expression.IdentReferFromTable,
|
||||
ReferIndex: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -69,7 +75,9 @@ func (t *testGroupBySuite) TestGroupBy(c *C) {
|
||||
Src: tblPlan,
|
||||
By: []expression.Expression{
|
||||
&expression.Ident{
|
||||
CIStr: model.NewCIStr("id"),
|
||||
CIStr: model.NewCIStr("id"),
|
||||
ReferScope: expression.IdentReferFromTable,
|
||||
ReferIndex: 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -62,10 +62,11 @@ func (r *HavingPlan) Next(ctx context.Context) (row *plan.Row, err error) {
|
||||
if srcRow == nil || err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
|
||||
v, err0 := GetIdentValue(name, r.Src.GetFields(), srcRow.Data, field.CheckFieldFlag)
|
||||
if err0 == nil {
|
||||
return v, nil
|
||||
r.evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
|
||||
if scope == expression.IdentReferFromTable {
|
||||
return srcRow.FromData[index], nil
|
||||
} else if scope == expression.IdentReferSelectList {
|
||||
return srcRow.Data[index], nil
|
||||
}
|
||||
|
||||
// try to find in outer query
|
||||
|
||||
@ -38,7 +38,9 @@ func (t *testHavingPlan) TestHaving(c *C) {
|
||||
Expr: &expression.BinaryOperation{
|
||||
Op: opcode.GE,
|
||||
L: &expression.Ident{
|
||||
CIStr: model.NewCIStr("id"),
|
||||
CIStr: model.NewCIStr("id"),
|
||||
ReferScope: expression.IdentReferSelectList,
|
||||
ReferIndex: 0,
|
||||
},
|
||||
R: &expression.Value{
|
||||
Val: 20,
|
||||
|
||||
@ -24,7 +24,7 @@ import (
|
||||
"github.com/pingcap/tidb/field"
|
||||
"github.com/pingcap/tidb/kv"
|
||||
"github.com/pingcap/tidb/model"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/parser/opcode"
|
||||
"github.com/pingcap/tidb/plan/plans"
|
||||
"github.com/pingcap/tidb/rset/rsets"
|
||||
|
||||
@ -25,7 +25,7 @@ import (
|
||||
"github.com/pingcap/tidb/field"
|
||||
"github.com/pingcap/tidb/infoschema"
|
||||
"github.com/pingcap/tidb/model"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/plan"
|
||||
"github.com/pingcap/tidb/sessionctx"
|
||||
"github.com/pingcap/tidb/util/charset"
|
||||
|
||||
@ -155,6 +155,7 @@ func (r *JoinPlan) Next(ctx context.Context) (row *plan.Row, err error) {
|
||||
}
|
||||
|
||||
func (r *JoinPlan) nextLeftJoin(ctx context.Context) (row *plan.Row, err error) {
|
||||
visitor := newJoinIdentVisitor(0)
|
||||
for {
|
||||
if r.cursor < len(r.matchedRows) {
|
||||
row = r.matchedRows[r.cursor]
|
||||
@ -167,7 +168,7 @@ func (r *JoinPlan) nextLeftJoin(ctx context.Context) (row *plan.Row, err error)
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
tempExpr := r.On.Clone()
|
||||
visitor := NewIdentEvalVisitor(r.Left.GetFields(), leftRow.Data)
|
||||
visitor.row = leftRow.Data
|
||||
_, err = tempExpr.Accept(visitor)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
@ -189,6 +190,7 @@ func (r *JoinPlan) nextLeftJoin(ctx context.Context) (row *plan.Row, err error)
|
||||
}
|
||||
|
||||
func (r *JoinPlan) nextRightJoin(ctx context.Context) (row *plan.Row, err error) {
|
||||
visitor := newJoinIdentVisitor(len(r.Left.GetFields()))
|
||||
for {
|
||||
if r.cursor < len(r.matchedRows) {
|
||||
row = r.matchedRows[r.cursor]
|
||||
@ -202,7 +204,7 @@ func (r *JoinPlan) nextRightJoin(ctx context.Context) (row *plan.Row, err error)
|
||||
}
|
||||
|
||||
tempExpr := r.On.Clone()
|
||||
visitor := NewIdentEvalVisitor(r.Right.GetFields(), rightRow.Data)
|
||||
visitor.row = rightRow.Data
|
||||
_, err = tempExpr.Accept(visitor)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
@ -251,8 +253,8 @@ func (r *JoinPlan) findMatchedRows(ctx context.Context, row *plan.Row, p plan.Pl
|
||||
} else {
|
||||
joined = append(append(joined, row.Data...), cmpRow.Data...)
|
||||
}
|
||||
r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
|
||||
return GetIdentValue(name, r.Fields, joined, field.DefaultFieldFlag)
|
||||
r.evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
|
||||
return joined[index], nil
|
||||
}
|
||||
var b bool
|
||||
b, err = expression.EvalBoolExpr(ctx, r.On, r.evalArgs)
|
||||
@ -280,6 +282,7 @@ func (r *JoinPlan) findMatchedRows(ctx context.Context, row *plan.Row, p plan.Pl
|
||||
}
|
||||
|
||||
func (r *JoinPlan) nextCrossJoin(ctx context.Context) (row *plan.Row, err error) {
|
||||
visitor := newJoinIdentVisitor(0)
|
||||
for {
|
||||
if r.curRow == nil {
|
||||
r.curRow, err = r.Left.Next(ctx)
|
||||
@ -292,7 +295,7 @@ func (r *JoinPlan) nextCrossJoin(ctx context.Context) (row *plan.Row, err error)
|
||||
|
||||
if r.On != nil {
|
||||
tempExpr := r.On.Clone()
|
||||
visitor := NewIdentEvalVisitor(r.Left.GetFields(), r.curRow.Data)
|
||||
visitor.row = r.curRow.Data
|
||||
_, err = tempExpr.Accept(visitor)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
@ -323,8 +326,8 @@ func (r *JoinPlan) nextCrossJoin(ctx context.Context) (row *plan.Row, err error)
|
||||
joinedRow = append(append(joinedRow, r.curRow.Data...), rightRow.Data...)
|
||||
|
||||
if r.On != nil {
|
||||
r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
|
||||
return GetIdentValue(name, r.Fields, joinedRow, field.DefaultFieldFlag)
|
||||
r.evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
|
||||
return joinedRow[index], nil
|
||||
}
|
||||
|
||||
b, err := expression.EvalBoolExpr(ctx, r.On, r.evalArgs)
|
||||
@ -356,32 +359,34 @@ func (r *JoinPlan) Close() error {
|
||||
return r.Left.Close()
|
||||
}
|
||||
|
||||
// IdentEvalVisitor converts Ident expression to value expression.
|
||||
type IdentEvalVisitor struct {
|
||||
// joinIdentVisitor converts Ident expression to value expression in ON expression.
|
||||
type joinIdentVisitor struct {
|
||||
expression.BaseVisitor
|
||||
fields []*field.ResultField
|
||||
row []interface{}
|
||||
offset int
|
||||
}
|
||||
|
||||
// NewIdentEvalVisitor creates a new IdentEvalVisitor.
|
||||
func NewIdentEvalVisitor(fields []*field.ResultField, row []interface{}) *IdentEvalVisitor {
|
||||
iev := &IdentEvalVisitor{fields: fields, row: row}
|
||||
// newJoinIdentVisitor creates a new joinIdentVisitor.
|
||||
func newJoinIdentVisitor(offset int) *joinIdentVisitor {
|
||||
iev := &joinIdentVisitor{offset: offset}
|
||||
iev.BaseVisitor.V = iev
|
||||
return iev
|
||||
}
|
||||
|
||||
// VisitIdent implements Visitor interface.
|
||||
func (iev *IdentEvalVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) {
|
||||
v, err := GetIdentValue(i.L, iev.fields, iev.row, field.CheckFieldFlag)
|
||||
if err != nil {
|
||||
func (iev *joinIdentVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) {
|
||||
// the row here may be just left part or right part, but identifier may reference another part,
|
||||
// so here we must check its reference index validation.
|
||||
if i.ReferIndex < iev.offset || i.ReferIndex >= iev.offset+len(iev.row) {
|
||||
return i, nil
|
||||
}
|
||||
v := iev.row[i.ReferIndex-iev.offset]
|
||||
return expression.Value{Val: v}, nil
|
||||
}
|
||||
|
||||
// VisitBinaryOperation swaps the right side identifier to left side if left side expression is static.
|
||||
// So it can be used in index plan.
|
||||
func (iev *IdentEvalVisitor) VisitBinaryOperation(binop *expression.BinaryOperation) (expression.Expression, error) {
|
||||
func (iev *joinIdentVisitor) VisitBinaryOperation(binop *expression.BinaryOperation) (expression.Expression, error) {
|
||||
var err error
|
||||
binop.L, err = binop.L.Accept(iev)
|
||||
if err != nil {
|
||||
|
||||
@ -20,7 +20,7 @@ import (
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/field"
|
||||
"github.com/pingcap/tidb/model"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/parser/opcode"
|
||||
"github.com/pingcap/tidb/plan/plans"
|
||||
"github.com/pingcap/tidb/rset/rsets"
|
||||
|
||||
@ -158,10 +158,11 @@ func (r *OrderByDefaultPlan) fetchAll(ctx context.Context) error {
|
||||
if row == nil {
|
||||
break
|
||||
}
|
||||
evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
|
||||
v, err := GetIdentValue(name, r.ResultFields, row.Data, field.CheckFieldFlag)
|
||||
if err == nil {
|
||||
return v, nil
|
||||
evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
|
||||
if scope == expression.IdentReferFromTable {
|
||||
return row.FromData[index], nil
|
||||
} else if scope == expression.IdentReferSelectList {
|
||||
return row.Data[index], nil
|
||||
}
|
||||
|
||||
// try to find in outer query
|
||||
|
||||
@ -48,7 +48,9 @@ func (t *testOrderBySuit) TestOrderBy(c *C) {
|
||||
Src: tblPlan,
|
||||
By: []expression.Expression{
|
||||
&expression.Ident{
|
||||
CIStr: model.NewCIStr("id"),
|
||||
CIStr: model.NewCIStr("id"),
|
||||
ReferScope: expression.IdentReferSelectList,
|
||||
ReferIndex: 0,
|
||||
},
|
||||
},
|
||||
Ascs: []bool{false},
|
||||
|
||||
@ -22,7 +22,7 @@ import (
|
||||
"github.com/pingcap/tidb/context"
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/field"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/plan"
|
||||
"github.com/pingcap/tidb/util/charset"
|
||||
"github.com/pingcap/tidb/util/format"
|
||||
|
||||
@ -199,5 +199,5 @@ func getIdentValueFromOuterQuery(ctx context.Context, name string) (interface{},
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.Trace(err)
|
||||
return nil, errors.Errorf("unknown field %s", name)
|
||||
}
|
||||
|
||||
@ -95,63 +95,43 @@ func (s *SelectList) GetFields() []*field.ResultField {
|
||||
}
|
||||
|
||||
// UpdateAggFields adds aggregate function resultfield to select result field list.
|
||||
func (s *SelectList) UpdateAggFields(expr expression.Expression, tableFields []*field.ResultField) (expression.Expression, error) {
|
||||
// For aggregate function, the name can be in table or select list.
|
||||
names := expression.MentionedColumns(expr)
|
||||
|
||||
for _, name := range names {
|
||||
if field.ContainFieldName(name, tableFields, field.DefaultFieldFlag) {
|
||||
continue
|
||||
}
|
||||
|
||||
if field.ContainFieldName(name, s.ResultFields, field.DefaultFieldFlag) {
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, errors.Errorf("Unknown column '%s'", name)
|
||||
}
|
||||
|
||||
func (s *SelectList) UpdateAggFields(expr expression.Expression) (expression.Expression, error) {
|
||||
// We must add aggregate function to hidden select list
|
||||
// and use a position expression to fetch its value later.
|
||||
exprName := expr.String()
|
||||
idx := field.GetResultFieldIndex(exprName, s.ResultFields, field.CheckFieldFlag)
|
||||
if len(idx) == 0 {
|
||||
f := &field.Field{Expr: expr}
|
||||
resultField := &field.ResultField{Name: exprName}
|
||||
s.AddField(f, resultField)
|
||||
name := strings.ToLower(expr.String())
|
||||
index := -1
|
||||
for i := 0; i < s.HiddenFieldOffset; i++ {
|
||||
// only check origin name, e,g. "select sum(c1) as a from t order by sum(c1)"
|
||||
// or "select sum(c1) from t order by sum(c1)"
|
||||
if s.ResultFields[i].ColumnInfo.Name.L == name {
|
||||
index = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return &expression.Position{N: len(s.Fields), Name: exprName}, nil
|
||||
if index == -1 {
|
||||
f := &field.Field{Expr: expr}
|
||||
s.AddField(f, nil)
|
||||
|
||||
pos := len(s.Fields)
|
||||
|
||||
s.AggFields[pos-1] = struct{}{}
|
||||
|
||||
return &expression.Position{N: pos, Name: name}, nil
|
||||
}
|
||||
|
||||
// select list has this field, use it directly.
|
||||
return &expression.Position{N: idx[0] + 1, Name: exprName}, nil
|
||||
return &expression.Position{N: index + 1, Name: name}, nil
|
||||
}
|
||||
|
||||
// CloneHiddenField checks and clones field and result field from table fields,
|
||||
// and adds them to hidden field of select list.
|
||||
func (s *SelectList) CloneHiddenField(name string, tableFields []*field.ResultField) bool {
|
||||
// Check and add hidden field.
|
||||
if field.ContainFieldName(name, tableFields, field.CheckFieldFlag) {
|
||||
resultField, _ := field.CloneFieldByName(name, tableFields, field.CheckFieldFlag)
|
||||
f := &field.Field{
|
||||
Expr: &expression.Ident{
|
||||
CIStr: resultField.ColumnInfo.Name,
|
||||
},
|
||||
}
|
||||
s.AddField(f, resultField)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// CheckReferAmbiguous checks whether an identifier reference is ambiguous or not in select list.
|
||||
// CheckAmbiguous checks whether an identifier reference is ambiguous or not in select list.
|
||||
// e,g, "select c1 as a, c2 as a from t group by a" is ambiguous,
|
||||
// but "select c1 as a, c1 as a from t group by a" is not.
|
||||
// For MySQL "select c1 as a, c2 + 1 as a from t group by a" is not ambiguous too,
|
||||
// so we will only check identifier too.
|
||||
// If no ambiguous, -1 means expr refers none in select list, else an index in select list returns.
|
||||
func (s *SelectList) CheckReferAmbiguous(expr expression.Expression) (int, error) {
|
||||
// "select c1 as a, c2 + 1 as a from t group by a" is not ambiguous too,
|
||||
// If no ambiguous, -1 means expr refers none in select list, else an index for first match.
|
||||
// CheckAmbiguous will break the check when finding first matching which is not an indentifier,
|
||||
// or an index for an identifier field in the end, -1 means none found.
|
||||
func (s *SelectList) CheckAmbiguous(expr expression.Expression) (int, error) {
|
||||
if _, ok := expr.(*expression.Ident); !ok {
|
||||
return -1, nil
|
||||
}
|
||||
@ -162,14 +142,22 @@ func (s *SelectList) CheckReferAmbiguous(expr expression.Expression) (int, error
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
// select c1 as a, 1 as a, c2 as a from t order by a is not ambiguous.
|
||||
// select c1 as a, c2 as a from t order by a is ambiguous.
|
||||
// select 1 as a, c1 as a from t order by a is not ambiguous.
|
||||
// select c1 as a, sum(c1) as a from t group by a is error.
|
||||
// select c1 as a, 1 as a, sum(c1) as a from t group by a is not error.
|
||||
// so we will break the check if matching a none identifier field.
|
||||
lastIndex := -1
|
||||
// only check origin select list, no hidden field.
|
||||
for i := 0; i < s.HiddenFieldOffset; i++ {
|
||||
if !strings.EqualFold(s.ResultFields[i].Name, name) {
|
||||
continue
|
||||
} else if _, ok := s.Fields[i].Expr.(*expression.Ident); !ok {
|
||||
// not identfier, no check
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := s.Fields[i].Expr.(*expression.Ident); !ok {
|
||||
// not identfier, return directly.
|
||||
return i, nil
|
||||
}
|
||||
|
||||
if lastIndex == -1 {
|
||||
@ -179,6 +167,7 @@ func (s *SelectList) CheckReferAmbiguous(expr expression.Expression) (int, error
|
||||
}
|
||||
|
||||
// check origin name, e,g. "select c1 as c2, c2 from t group by c2" is ambiguous.
|
||||
|
||||
if s.ResultFields[i].ColumnInfo.Name.L != s.ResultFields[lastIndex].ColumnInfo.Name.L {
|
||||
return -1, errors.Errorf("refer %s is ambiguous", expr)
|
||||
}
|
||||
@ -233,6 +222,7 @@ func ResolveSelectList(selectFields []*field.Field, srcFields []*field.ResultFie
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO: use fromIdentVisitor to cleanup.
|
||||
var result *field.ResultField
|
||||
for _, name := range names {
|
||||
idx := field.GetResultFieldIndex(name, srcFields, field.DefaultFieldFlag)
|
||||
|
||||
@ -96,7 +96,7 @@ func (s *testSelectListSuite) TestAmbiguous(c *C) {
|
||||
sl, err := plans.ResolveSelectList(fs, rs)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
index, err := sl.CheckReferAmbiguous(&expression.Ident{
|
||||
idx, err := sl.CheckAmbiguous(&expression.Ident{
|
||||
CIStr: model.NewCIStr(t.Name),
|
||||
})
|
||||
|
||||
@ -106,6 +106,6 @@ func (s *testSelectListSuite) TestAmbiguous(c *C) {
|
||||
}
|
||||
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(t.Index, Equals, index)
|
||||
c.Assert(t.Index, DeepEquals, idx)
|
||||
}
|
||||
}
|
||||
|
||||
@ -25,7 +25,7 @@ import (
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/field"
|
||||
"github.com/pingcap/tidb/model"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/plan"
|
||||
"github.com/pingcap/tidb/sessionctx"
|
||||
"github.com/pingcap/tidb/sessionctx/variable"
|
||||
|
||||
@ -22,7 +22,7 @@ import (
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/kv"
|
||||
"github.com/pingcap/tidb/model"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/parser/opcode"
|
||||
"github.com/pingcap/tidb/plan/plans"
|
||||
"github.com/pingcap/tidb/rset/rsets"
|
||||
|
||||
@ -18,7 +18,7 @@ import (
|
||||
"github.com/pingcap/tidb/column"
|
||||
"github.com/pingcap/tidb/field"
|
||||
"github.com/pingcap/tidb/model"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/plan"
|
||||
"github.com/pingcap/tidb/plan/plans"
|
||||
"github.com/pingcap/tidb/rset/rsets"
|
||||
|
||||
@ -56,15 +56,15 @@ func (r *FilterDefaultPlan) Next(ctx context.Context) (row *plan.Row, err error)
|
||||
if row == nil || err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
|
||||
v, err0 := GetIdentValue(name, r.GetFields(), row.Data, field.DefaultFieldFlag)
|
||||
if err0 == nil {
|
||||
return v, nil
|
||||
r.evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
|
||||
if scope == expression.IdentReferFromTable {
|
||||
return row.Data[index], nil
|
||||
}
|
||||
|
||||
// try to find in outer query
|
||||
return getIdentValueFromOuterQuery(ctx, name)
|
||||
}
|
||||
|
||||
var meet bool
|
||||
meet, err = r.meetCondition(ctx)
|
||||
if err != nil {
|
||||
|
||||
@ -44,7 +44,9 @@ func (t *testWhereSuit) TestWhere(c *C) {
|
||||
Expr: &expression.BinaryOperation{
|
||||
Op: opcode.GE,
|
||||
L: &expression.Ident{
|
||||
CIStr: model.NewCIStr("id"),
|
||||
CIStr: model.NewCIStr("id"),
|
||||
ReferScope: expression.IdentReferFromTable,
|
||||
ReferIndex: 0,
|
||||
},
|
||||
R: expression.Value{
|
||||
Val: 30,
|
||||
|
||||
@ -39,10 +39,29 @@ type SelectFieldsRset struct {
|
||||
SelectList *plans.SelectList
|
||||
}
|
||||
|
||||
func updateSelectFieldsRefer(selectList *plans.SelectList) error {
|
||||
visitor := newFromIdentVisitor(selectList.FromFields, fieldListClause)
|
||||
|
||||
// we only fix un-hidden fields here, for hidden fields, it should be
|
||||
// handled in their own clause, in order by or having.
|
||||
for i, f := range selectList.Fields[0:selectList.HiddenFieldOffset] {
|
||||
e, err := f.Expr.Accept(visitor)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
selectList.Fields[i].Expr = e
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Plan gets SrcPlan/SelectFieldsDefaultPlan.
|
||||
// If all fields are equal to src plan fields, then gets SrcPlan.
|
||||
// Default gets SelectFieldsDefaultPlan.
|
||||
func (r *SelectFieldsRset) Plan(ctx context.Context) (plan.Plan, error) {
|
||||
if err := updateSelectFieldsRefer(r.SelectList); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
fields := r.SelectList.Fields
|
||||
srcFields := r.Src.GetFields()
|
||||
if len(fields) == len(srcFields) {
|
||||
@ -97,6 +116,7 @@ func (r *SelectFromDualRset) Plan(ctx context.Context) (plan.Plan, error) {
|
||||
// field cannot contain identifier
|
||||
for _, f := range r.Fields {
|
||||
if cs := expression.MentionedColumns(f.Expr); len(cs) > 0 {
|
||||
// TODO: check in outer query, like select * from t where t.c = (select c limit 1);
|
||||
return nil, errors.Errorf("Unknown column '%s' in 'field list'", cs[0])
|
||||
}
|
||||
}
|
||||
|
||||
@ -69,6 +69,7 @@ func (s *testSelectFieldsPlannerSuite) TestDistinctPlanner(c *C) {
|
||||
|
||||
oldFld := s.sr.SelectList.Fields[1]
|
||||
s.sr.SelectList.Fields = s.sr.SelectList.Fields[:1]
|
||||
s.sr.SelectList.HiddenFieldOffset = len(s.sr.SelectList.Fields)
|
||||
|
||||
p, err = s.sr.Plan(nil)
|
||||
c.Assert(err, IsNil)
|
||||
@ -81,6 +82,7 @@ func (s *testSelectFieldsPlannerSuite) TestDistinctPlanner(c *C) {
|
||||
s.sr.Src = tdp
|
||||
|
||||
s.sr.SelectList.Fields = []*field.Field{fld}
|
||||
s.sr.SelectList.HiddenFieldOffset = len(s.sr.SelectList.Fields)
|
||||
|
||||
p, err = s.sr.Plan(nil)
|
||||
c.Assert(err, IsNil)
|
||||
@ -94,6 +96,7 @@ func (s *testSelectFieldsPlannerSuite) TestDistinctPlanner(c *C) {
|
||||
|
||||
// cover isConst check, like `select 1, c1 from t`
|
||||
s.sr.SelectList.Fields = []*field.Field{fld, oldFld}
|
||||
s.sr.SelectList.HiddenFieldOffset = len(s.sr.SelectList.Fields)
|
||||
p, err = s.sr.Plan(nil)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
|
||||
@ -39,70 +39,110 @@ type GroupByRset struct {
|
||||
SelectList *plans.SelectList
|
||||
}
|
||||
|
||||
type groupByVisitor struct {
|
||||
expression.BaseVisitor
|
||||
selectList *plans.SelectList
|
||||
rootIdent *expression.Ident
|
||||
}
|
||||
|
||||
func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) {
|
||||
// Group by ambiguous rule:
|
||||
// select c1 as a, c2 as a from t group by a is ambiguous
|
||||
// select c1 as a, c2 as a from t group by a + 1 is ambiguous
|
||||
// select c1 as c2, c2 from t group by c2 is ambiguous
|
||||
// select c1 as c2, c2 from t group by c2 + 1 is not ambiguous
|
||||
|
||||
var (
|
||||
index int
|
||||
err error
|
||||
)
|
||||
|
||||
if v.rootIdent == i {
|
||||
// The group by is an identifier, we must check it first.
|
||||
index, err = checkIdentAmbiguous(i, v.selectList, groupByClause)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
}
|
||||
|
||||
// first find this identifier in FROM.
|
||||
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag)
|
||||
if len(idx) > 0 {
|
||||
i.ReferScope = expression.IdentReferFromTable
|
||||
i.ReferIndex = idx[0]
|
||||
return i, nil
|
||||
}
|
||||
|
||||
if v.rootIdent != i {
|
||||
// This identifier is the part of the group by, check ambiguous here.
|
||||
index, err = checkIdentAmbiguous(i, v.selectList, groupByClause)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
}
|
||||
|
||||
// try to find in select list, we have got index using checkIdent before.
|
||||
if index >= 0 {
|
||||
// group by can not reference aggregate fields
|
||||
if _, ok := v.selectList.AggFields[index]; ok {
|
||||
return nil, errors.Errorf("Reference '%s' not supported (reference to group function)", i)
|
||||
}
|
||||
|
||||
// find in select list
|
||||
i.ReferScope = expression.IdentReferSelectList
|
||||
i.ReferIndex = index
|
||||
return i, nil
|
||||
}
|
||||
|
||||
// TODO: check in out query
|
||||
return i, errors.Errorf("Unknown column '%s' in 'group statement'", i)
|
||||
}
|
||||
|
||||
func (v *groupByVisitor) VisitCall(c *expression.Call) (expression.Expression, error) {
|
||||
ok := expression.IsAggregateFunc(c.F)
|
||||
if ok {
|
||||
return nil, errors.Errorf("group by cannot contain aggregate function %s", c)
|
||||
}
|
||||
|
||||
var err error
|
||||
for i, e := range c.Args {
|
||||
c.Args[i], err = e.Accept(v)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Plan gets GroupByDefaultPlan.
|
||||
func (r *GroupByRset) Plan(ctx context.Context) (plan.Plan, error) {
|
||||
fields := r.SelectList.Fields
|
||||
if err := updateSelectFieldsRefer(r.SelectList); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
r.SelectList.AggFields = GetAggFields(fields)
|
||||
aggFields := r.SelectList.AggFields
|
||||
visitor := &groupByVisitor{}
|
||||
visitor.BaseVisitor.V = visitor
|
||||
visitor.selectList = r.SelectList
|
||||
|
||||
for i, e := range r.By {
|
||||
if v, ok := e.(expression.Value); ok {
|
||||
var position int
|
||||
switch u := v.Val.(type) {
|
||||
case int64:
|
||||
position = int(u)
|
||||
case uint64:
|
||||
position = int(u)
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
if position < 1 || position > len(fields) {
|
||||
return nil, errors.Errorf("Unknown column '%d' in 'group statement'", position)
|
||||
}
|
||||
|
||||
index := position - 1
|
||||
if _, ok := aggFields[index]; ok {
|
||||
return nil, errors.Errorf("Can't group on '%s'", fields[index])
|
||||
}
|
||||
|
||||
// use Position expression for the associated field.
|
||||
r.By[i] = &expression.Position{N: position}
|
||||
} else {
|
||||
index, err := r.SelectList.CheckReferAmbiguous(e)
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("Column '%s' in group statement is ambiguous", e)
|
||||
} else if _, ok := aggFields[index]; ok {
|
||||
return nil, errors.Errorf("Can't group on '%s'", e)
|
||||
}
|
||||
|
||||
// TODO: check more ambiguous case
|
||||
// Group by ambiguous rule:
|
||||
// select c1 as a, c2 as a from t group by a is ambiguous
|
||||
// select c1 as a, c2 as a from t group by a + 1 is ambiguous
|
||||
// select c1 as c2, c2 from t group by c2 is ambiguous
|
||||
// select c1 as c2, c2 from t group by c2 + 1 is ambiguous
|
||||
|
||||
// TODO: use visitor to check aggregate function
|
||||
names := expression.MentionedColumns(e)
|
||||
for _, name := range names {
|
||||
indices := field.GetFieldIndex(name, fields[0:r.SelectList.HiddenFieldOffset], field.DefaultFieldFlag)
|
||||
if len(indices) == 1 {
|
||||
// check reference to aggregate function, like `select c1, count(c1) as b from t group by b + 1`.
|
||||
index := indices[0]
|
||||
if _, ok := aggFields[index]; ok {
|
||||
return nil, errors.Errorf("Reference '%s' not supported (reference to group function)", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// group by should be an expression, a qualified field name or a select field position,
|
||||
// but can not contain any aggregate function.
|
||||
if e := r.By[i]; expression.ContainAggregateFunc(e) {
|
||||
return nil, errors.Errorf("group by cannot contain aggregate function %s", e.String())
|
||||
}
|
||||
pos, err := castPosition(e, r.SelectList, groupByClause)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
if pos != nil {
|
||||
// use Position expression for the associated field.
|
||||
r.By[i] = pos
|
||||
continue
|
||||
}
|
||||
|
||||
visitor.rootIdent = castIdent(e)
|
||||
by, err := e.Accept(visitor)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
r.By[i] = by
|
||||
}
|
||||
|
||||
return &plans.GroupByDefaultPlan{By: r.By, Src: r.Src,
|
||||
|
||||
@ -50,6 +50,11 @@ func (s *testGroupByRsetSuite) SetUpSuite(c *C) {
|
||||
s.r = &GroupByRset{Src: tblPlan, SelectList: selectList, By: by}
|
||||
}
|
||||
|
||||
func resetAggFields(selectList *plans.SelectList) {
|
||||
fields := selectList.Fields
|
||||
selectList.AggFields = GetAggFields(fields)
|
||||
}
|
||||
|
||||
func (s *testGroupByRsetSuite) TestGroupByRsetPlan(c *C) {
|
||||
// `select id, name from t group by name`
|
||||
p, err := s.r.Plan(nil)
|
||||
@ -88,6 +93,8 @@ func (s *testGroupByRsetSuite) TestGroupByRsetPlan(c *C) {
|
||||
fld := &field.Field{Expr: fldExpr, AsName: "a"}
|
||||
s.r.SelectList.Fields[0] = fld
|
||||
|
||||
resetAggFields(s.r.SelectList)
|
||||
|
||||
s.r.By[0] = expression.Value{Val: int64(1)}
|
||||
|
||||
_, err = s.r.Plan(nil)
|
||||
@ -115,6 +122,8 @@ func (s *testGroupByRsetSuite) TestGroupByRsetPlan(c *C) {
|
||||
s.r.SelectList.ResultFields[0].Col.Name = model.NewCIStr("count(id)")
|
||||
s.r.SelectList.ResultFields[0].Name = "a"
|
||||
|
||||
resetAggFields(s.r.SelectList)
|
||||
|
||||
_, err = s.r.Plan(nil)
|
||||
c.Assert(err, NotNil)
|
||||
|
||||
|
||||
@ -31,66 +31,274 @@ type HavingRset struct {
|
||||
Src plan.Plan
|
||||
Expr expression.Expression
|
||||
SelectList *plans.SelectList
|
||||
// Group by clause
|
||||
GroupBy []expression.Expression
|
||||
}
|
||||
|
||||
// CheckAndUpdateSelectList checks having fields validity and set hidden fields to selectList.
|
||||
func (r *HavingRset) CheckAndUpdateSelectList(selectList *plans.SelectList, groupBy []expression.Expression, tableFields []*field.ResultField) error {
|
||||
// CheckAggregate will check whether order by has aggregate function or not,
|
||||
// if has, we will add it to select list hidden field.
|
||||
func (r *HavingRset) CheckAggregate(selectList *plans.SelectList) error {
|
||||
if expression.ContainAggregateFunc(r.Expr) {
|
||||
expr, err := selectList.UpdateAggFields(r.Expr, tableFields)
|
||||
expr, err := selectList.UpdateAggFields(r.Expr)
|
||||
if err != nil {
|
||||
return errors.Errorf("%s in 'having clause'", err.Error())
|
||||
}
|
||||
|
||||
r.Expr = expr
|
||||
} else {
|
||||
// having can only contain group by column and select list, e.g,
|
||||
// `select c1 from t group by c2 having c3 > 0` is invalid,
|
||||
// because c3 is not in group by and select list.
|
||||
names := expression.MentionedColumns(r.Expr)
|
||||
for _, name := range names {
|
||||
found := false
|
||||
|
||||
// check name whether in select list.
|
||||
// notice that `select c1 as c2 from t group by c1, c2, c3 having c2 > c3`,
|
||||
// will use t.c2 not t.c1 here.
|
||||
if field.ContainFieldName(name, selectList.ResultFields, field.OrgFieldNameFlag) {
|
||||
continue
|
||||
}
|
||||
if field.ContainFieldName(name, selectList.ResultFields, field.FieldNameFlag) {
|
||||
if field.ContainFieldName(name, tableFields, field.OrgFieldNameFlag) {
|
||||
selectList.CloneHiddenField(name, tableFields)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// check name whether in group by.
|
||||
// group by must only have column name, e.g,
|
||||
// `select c1 from t group by c2 having c2 > 0` is valid,
|
||||
// but `select c1 from t group by c2 + 1 having c2 > 0` is invalid.
|
||||
for _, by := range groupBy {
|
||||
if !field.CheckFieldsEqual(name, by.String()) {
|
||||
continue
|
||||
}
|
||||
|
||||
// if name is not in table fields, it will get an unknown field error in GroupByRset,
|
||||
// so no need to check return value.
|
||||
selectList.CloneHiddenField(name, tableFields)
|
||||
|
||||
found = true
|
||||
break
|
||||
}
|
||||
|
||||
if !found {
|
||||
return errors.Errorf("Unknown column '%s' in 'having clause'", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type havingVisitor struct {
|
||||
expression.BaseVisitor
|
||||
selectList *plans.SelectList
|
||||
|
||||
// for group by
|
||||
groupBy []expression.Expression
|
||||
|
||||
// true means we are visiting aggregate function arguments now.
|
||||
inAggregate bool
|
||||
}
|
||||
|
||||
func (v *havingVisitor) visitIdentInAggregate(i *expression.Ident) (expression.Expression, error) {
|
||||
// if we are visiting aggregate function arguments, the identifier first checks in from table,
|
||||
// then in select list, and outer query finally.
|
||||
|
||||
// find this identifier in FROM.
|
||||
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag)
|
||||
if len(idx) > 0 {
|
||||
i.ReferScope = expression.IdentReferFromTable
|
||||
i.ReferIndex = idx[0]
|
||||
return i, nil
|
||||
}
|
||||
|
||||
// check in select list.
|
||||
index, err := checkIdentAmbiguous(i, v.selectList, havingClause)
|
||||
if err != nil {
|
||||
return i, errors.Trace(err)
|
||||
}
|
||||
|
||||
if index >= 0 {
|
||||
// find in select list
|
||||
i.ReferScope = expression.IdentReferSelectList
|
||||
i.ReferIndex = index
|
||||
return i, nil
|
||||
}
|
||||
|
||||
// TODO: check in out query
|
||||
// TODO: return unknown field error, but now just return directly.
|
||||
// Because this may reference outer query.
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func (v *havingVisitor) checkIdentInGroupBy(i *expression.Ident) (*expression.Ident, bool, error) {
|
||||
for _, by := range v.groupBy {
|
||||
e := castIdent(by)
|
||||
if e == nil {
|
||||
// group by must be a identifier too
|
||||
continue
|
||||
}
|
||||
|
||||
if !field.CheckFieldsEqual(e.L, i.L) {
|
||||
// not same, continue
|
||||
continue
|
||||
}
|
||||
|
||||
// if group by references select list or having is qualified identifier,
|
||||
// no other check.
|
||||
if e.ReferScope == expression.IdentReferSelectList || field.IsQualifiedName(i.L) {
|
||||
i.ReferScope = e.ReferScope
|
||||
i.ReferIndex = e.ReferIndex
|
||||
return i, true, nil
|
||||
}
|
||||
|
||||
// having is unqualified name, e.g, select * from t1, t2 group by t1.c having c.
|
||||
// both t1 and t2 have column c, we must check ambiguous here.
|
||||
|
||||
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag)
|
||||
if len(idx) > 1 {
|
||||
return i, false, errors.Errorf("Column '%s' in having clause is ambiguous", i)
|
||||
}
|
||||
|
||||
i.ReferScope = e.ReferScope
|
||||
i.ReferIndex = e.ReferIndex
|
||||
return i, true, nil
|
||||
}
|
||||
|
||||
return i, false, nil
|
||||
}
|
||||
|
||||
func (v *havingVisitor) checkIdentInSelectList(i *expression.Ident) (*expression.Ident, bool, error) {
|
||||
index, err := checkIdentAmbiguous(i, v.selectList, havingClause)
|
||||
if err != nil {
|
||||
return i, false, errors.Trace(err)
|
||||
}
|
||||
|
||||
if index >= 0 {
|
||||
// identifier references a select field. use it directly.
|
||||
// e,g. select c1 as c2 from t having c2, here c2 references c1.
|
||||
i.ReferScope = expression.IdentReferSelectList
|
||||
i.ReferIndex = index
|
||||
return i, true, nil
|
||||
}
|
||||
|
||||
lastIndex := -1
|
||||
var lastFieldName string
|
||||
// we may meet this select c1 as c2 from t having c1, so we must check origin field name too.
|
||||
for index := 0; index < v.selectList.HiddenFieldOffset; index++ {
|
||||
e := castIdent(v.selectList.Fields[index].Expr)
|
||||
if e == nil {
|
||||
// not identifier
|
||||
continue
|
||||
}
|
||||
|
||||
if !field.CheckFieldsEqual(e.L, i.L) {
|
||||
// not same, continue
|
||||
continue
|
||||
}
|
||||
|
||||
if field.IsQualifiedName(i.L) {
|
||||
// qualified name, no need check
|
||||
i.ReferScope = expression.IdentReferSelectList
|
||||
i.ReferIndex = index
|
||||
return i, true, nil
|
||||
}
|
||||
|
||||
if lastIndex == -1 {
|
||||
lastIndex = index
|
||||
lastFieldName = e.L
|
||||
continue
|
||||
}
|
||||
|
||||
// we may meet select t1.c as a, t2.c as b from t1, t2 having c, must check ambiguous here
|
||||
if !field.CheckFieldsEqual(lastFieldName, e.L) {
|
||||
return i, false, errors.Errorf("Column '%s' in having clause is ambiguous", i)
|
||||
}
|
||||
|
||||
i.ReferScope = expression.IdentReferSelectList
|
||||
i.ReferIndex = index
|
||||
return i, true, nil
|
||||
}
|
||||
|
||||
if lastIndex != -1 {
|
||||
i.ReferScope = expression.IdentReferSelectList
|
||||
i.ReferIndex = lastIndex
|
||||
return i, true, nil
|
||||
}
|
||||
|
||||
return i, false, nil
|
||||
}
|
||||
|
||||
func (v *havingVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) {
|
||||
// Having has the most complex rule for ambiguous and identifer reference check.
|
||||
|
||||
// Having ambiguous rule:
|
||||
// select c1 as a, c2 as a from t having a is ambiguous
|
||||
// select c1 as a, c2 as a from t having a + 1 is ambiguous
|
||||
// select c1 as c2, c2 from t having c2 is ambiguous
|
||||
// select c1 as c2, c2 from t having c2 + 1 is ambiguous
|
||||
|
||||
// The identifier in having must exist in group by, select list or outer query, so select c1 from t having c2 is wrong,
|
||||
// the identifier will first try to find the reference in group by, then in select list and finally in outer query.
|
||||
|
||||
// Having identifier reference check
|
||||
// select c1 from t group by c2 having c2, c2 in having references group by c2.
|
||||
// select c1 as c2 from t having c2, c2 in having references c1 in select field.
|
||||
// select c1 as c2 from t group by c2 having c2, c2 in having references group by c2.
|
||||
// select c1 as c2 from t having sum(c2), c2 in having sum(c2) references t.c2 in from table.
|
||||
// select c1 as c2 from t having sum(c2) + c2, c2 in having sum(c2) references t.c2 in from table, but another c2 references c1 in select list.
|
||||
// select c1 as c2 from t group by c2 having sum(c2) + c2, c2 in having sum(c2) references t.c2 in from table, another c2 references c2 in group by.
|
||||
// select c1 as a from t having a, a references c1 in select field.
|
||||
// select c1 as a from t having sum(a), a in order by sum(a) references c1 in select field.
|
||||
// select c1 as c2 from t having c1, c1 in having references c1 in select list.
|
||||
|
||||
// TODO: unify VisitIdent for group by and order by visitor, they are little different.
|
||||
|
||||
if v.inAggregate {
|
||||
return v.visitIdentInAggregate(i)
|
||||
}
|
||||
|
||||
var (
|
||||
err error
|
||||
ok bool
|
||||
)
|
||||
|
||||
// first, check in group by
|
||||
i, ok, err = v.checkIdentInGroupBy(i)
|
||||
if err != nil || ok {
|
||||
return i, errors.Trace(err)
|
||||
}
|
||||
|
||||
// then in select list
|
||||
i, ok, err = v.checkIdentInSelectList(i)
|
||||
if err != nil || ok {
|
||||
return i, errors.Trace(err)
|
||||
}
|
||||
|
||||
// TODO: check in out query
|
||||
|
||||
return i, errors.Errorf("Unknown column '%s' in 'having clause'", i)
|
||||
}
|
||||
|
||||
func (v *havingVisitor) VisitPosition(p *expression.Position) (expression.Expression, error) {
|
||||
n := p.N
|
||||
|
||||
// we have added order by to hidden field before, now visit it here.
|
||||
expr := v.selectList.Fields[n-1].Expr
|
||||
e, err := expr.Accept(v)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
v.selectList.Fields[n-1].Expr = e
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (v *havingVisitor) VisitCall(c *expression.Call) (expression.Expression, error) {
|
||||
isAggregate := expression.IsAggregateFunc(c.F)
|
||||
|
||||
if v.inAggregate && isAggregate {
|
||||
// aggregate function can't contain aggregate function
|
||||
return nil, errors.Errorf("Invalid use of group function")
|
||||
}
|
||||
|
||||
if isAggregate {
|
||||
// set true to let outer know we are in aggregate function.
|
||||
v.inAggregate = true
|
||||
}
|
||||
|
||||
var err error
|
||||
for i, e := range c.Args {
|
||||
c.Args[i], err = e.Accept(v)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
}
|
||||
|
||||
if isAggregate {
|
||||
v.inAggregate = false
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Plan gets HavingPlan.
|
||||
func (r *HavingRset) Plan(ctx context.Context) (plan.Plan, error) {
|
||||
visitor := &havingVisitor{
|
||||
selectList: r.SelectList,
|
||||
groupBy: r.GroupBy,
|
||||
inAggregate: false,
|
||||
}
|
||||
visitor.BaseVisitor.V = visitor
|
||||
|
||||
e, err := r.Expr.Accept(visitor)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
r.Expr = e
|
||||
|
||||
return &plans.HavingPlan{Src: r.Src, Expr: r.Expr, SelectList: r.SelectList}, nil
|
||||
}
|
||||
|
||||
|
||||
@ -50,59 +50,6 @@ func (s *testHavingRsetSuite) SetUpSuite(c *C) {
|
||||
s.r = &HavingRset{Src: tblPlan, Expr: expr, SelectList: selectList}
|
||||
}
|
||||
|
||||
func (s *testHavingRsetSuite) TestHavingRsetCheckAndUpdateSelectList(c *C) {
|
||||
resultFields := s.r.Src.GetFields()
|
||||
|
||||
selectList := s.r.SelectList
|
||||
|
||||
groupBy := []expression.Expression{}
|
||||
|
||||
// `select id, name from t having id > 1`
|
||||
err := s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
// `select name from t group by id having id > 1`
|
||||
selectList.ResultFields = selectList.ResultFields[1:]
|
||||
selectList.Fields = selectList.Fields[1:]
|
||||
|
||||
groupBy = []expression.Expression{&expression.Ident{CIStr: model.NewCIStr("id")}}
|
||||
err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
// `select name from t group by id + 1 having id > 1`
|
||||
expr := expression.NewBinaryOperation(opcode.Plus, &expression.Ident{CIStr: model.NewCIStr("id")}, expression.Value{Val: 1})
|
||||
|
||||
groupBy = []expression.Expression{expr}
|
||||
err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
// `select name from t group by id + 1 having count(1) > 1`
|
||||
aggExpr, err := expression.NewCall("count", []expression.Expression{expression.Value{Val: 1}}, false)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
s.r.Expr = aggExpr
|
||||
|
||||
err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
// `select name from t group by id + 1 having count(xxx) > 1`
|
||||
aggExpr, err = expression.NewCall("count", []expression.Expression{&expression.Ident{CIStr: model.NewCIStr("xxx")}}, false)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
s.r.Expr = aggExpr
|
||||
|
||||
err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields)
|
||||
c.Assert(err, NotNil)
|
||||
|
||||
// `select name from t group by id having xxx > 1`
|
||||
expr = expression.NewBinaryOperation(opcode.GT, &expression.Ident{CIStr: model.NewCIStr("xxx")}, expression.Value{Val: 1})
|
||||
|
||||
s.r.Expr = expr
|
||||
|
||||
err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields)
|
||||
c.Assert(err, NotNil)
|
||||
}
|
||||
|
||||
func (s *testHavingRsetSuite) TestHavingRsetPlan(c *C) {
|
||||
p, err := s.r.Plan(nil)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
@ -14,13 +14,15 @@
|
||||
package rsets
|
||||
|
||||
import (
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/field"
|
||||
"github.com/pingcap/tidb/plan/plans"
|
||||
)
|
||||
|
||||
// GetAggFields gets aggregate fields position map.
|
||||
func GetAggFields(fields []*field.Field) map[int]struct{} {
|
||||
aggFields := map[int]struct{}{}
|
||||
aggFields := make(map[int]struct{}, len(fields))
|
||||
for i, v := range fields {
|
||||
if expression.ContainAggregateFunc(v.Expr) {
|
||||
aggFields[i] = struct{}{}
|
||||
@ -34,3 +36,122 @@ func HasAggFields(fields []*field.Field) bool {
|
||||
aggFields := GetAggFields(fields)
|
||||
return len(aggFields) > 0
|
||||
}
|
||||
|
||||
// castIdent returns an Ident expression if e is or nil.
|
||||
func castIdent(e expression.Expression) *expression.Ident {
|
||||
i, ok := e.(*expression.Ident)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
// TODO: export clause type and move to plan?
|
||||
type clauseType int
|
||||
|
||||
const (
|
||||
noneClause clauseType = iota
|
||||
onClause
|
||||
whereClause
|
||||
groupByClause
|
||||
fieldListClause
|
||||
havingClause
|
||||
orderByClause
|
||||
)
|
||||
|
||||
func (clause clauseType) String() string {
|
||||
switch clause {
|
||||
case onClause:
|
||||
return "on clause"
|
||||
case fieldListClause:
|
||||
return "field list"
|
||||
case whereClause:
|
||||
return "where clause"
|
||||
case groupByClause:
|
||||
return "group statement"
|
||||
case orderByClause:
|
||||
return "order clause"
|
||||
case havingClause:
|
||||
return "having clause"
|
||||
}
|
||||
return "none"
|
||||
}
|
||||
|
||||
// castPosition returns an group/order by Position expression if e is a number.
|
||||
func castPosition(e expression.Expression, selectList *plans.SelectList, clause clauseType) (*expression.Position, error) {
|
||||
v, ok := e.(expression.Value)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var position int
|
||||
switch u := v.Val.(type) {
|
||||
case int64:
|
||||
position = int(u)
|
||||
case uint64:
|
||||
position = int(u)
|
||||
default:
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if position < 1 || position > selectList.HiddenFieldOffset {
|
||||
return nil, errors.Errorf("Unknown column '%d' in '%s'", position, clause)
|
||||
}
|
||||
|
||||
if clause == groupByClause {
|
||||
index := position - 1
|
||||
if _, ok := selectList.AggFields[index]; ok {
|
||||
return nil, errors.Errorf("Can't group on '%s'", selectList.Fields[index])
|
||||
}
|
||||
}
|
||||
|
||||
// use Position expression for the associated field.
|
||||
return &expression.Position{N: position}, nil
|
||||
}
|
||||
|
||||
func checkIdentAmbiguous(i *expression.Ident, selectList *plans.SelectList, clause clauseType) (int, error) {
|
||||
index, err := selectList.CheckAmbiguous(i)
|
||||
if err != nil {
|
||||
return -1, errors.Errorf("Column '%s' in %s is ambiguous", i, clause)
|
||||
} else if index == -1 {
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
return index, nil
|
||||
}
|
||||
|
||||
// fromIdentVisitor can only handle identifier which reference FROM table or outer query.
|
||||
// like in common select list, where or join on condition.
|
||||
type fromIdentVisitor struct {
|
||||
expression.BaseVisitor
|
||||
fromFields []*field.ResultField
|
||||
clause clauseType
|
||||
}
|
||||
|
||||
func (v *fromIdentVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) {
|
||||
idx := field.GetResultFieldIndex(i.L, v.fromFields, field.DefaultFieldFlag)
|
||||
if len(idx) == 1 {
|
||||
i.ReferScope = expression.IdentReferFromTable
|
||||
i.ReferIndex = idx[0]
|
||||
return i, nil
|
||||
} else if len(idx) > 1 {
|
||||
return nil, errors.Errorf("Column '%s' in %s is ambiguous", i, v.clause)
|
||||
}
|
||||
|
||||
if v.clause == onClause {
|
||||
// on clause can't check outer query.
|
||||
return nil, errors.Errorf("Unknown column '%s' in '%s'", i, v.clause)
|
||||
}
|
||||
|
||||
// TODO: check in outer query
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func newFromIdentVisitor(fromFields []*field.ResultField, clause clauseType) *fromIdentVisitor {
|
||||
visitor := &fromIdentVisitor{}
|
||||
visitor.BaseVisitor.V = visitor
|
||||
visitor.fromFields = fromFields
|
||||
visitor.clause = clause
|
||||
|
||||
return visitor
|
||||
}
|
||||
|
||||
@ -141,6 +141,17 @@ func (r *JoinRset) buildJoinPlan(ctx context.Context, p *plans.JoinPlan, s *Join
|
||||
p.Fields = append(p.Fields, rightFields...)
|
||||
}
|
||||
|
||||
if p.On != nil {
|
||||
visitor := newFromIdentVisitor(p.Fields, onClause)
|
||||
|
||||
e, err := p.On.Accept(visitor)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
p.On = e
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@ -65,46 +65,110 @@ func (r *OrderByRset) String() string {
|
||||
return strings.Join(a, ", ")
|
||||
}
|
||||
|
||||
// CheckAndUpdateSelectList checks order by fields validity and set hidden fields to selectList.
|
||||
func (r *OrderByRset) CheckAndUpdateSelectList(selectList *plans.SelectList, tableFields []*field.ResultField) error {
|
||||
// CheckAggregate will check whether order by has aggregate function or not,
|
||||
// if has, we will add it to select list hidden field.
|
||||
func (r *OrderByRset) CheckAggregate(selectList *plans.SelectList) error {
|
||||
for i, v := range r.By {
|
||||
if expression.ContainAggregateFunc(v.Expr) {
|
||||
expr, err := selectList.UpdateAggFields(v.Expr, tableFields)
|
||||
expr, err := selectList.UpdateAggFields(v.Expr)
|
||||
if err != nil {
|
||||
return errors.Errorf("%s in 'order clause'", err.Error())
|
||||
}
|
||||
|
||||
r.By[i].Expr = expr
|
||||
} else {
|
||||
if _, err := selectList.CheckReferAmbiguous(v.Expr); err != nil {
|
||||
return errors.Errorf("Column '%s' in order statement is ambiguous", v.Expr)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: check more ambiguous case
|
||||
// Order by ambiguous rule:
|
||||
// select c1 as a, c2 as a from t order by a is ambiguous
|
||||
// select c1 as a, c2 as a from t order by a + 1 is ambiguous
|
||||
// select c1 as c2, c2 from t order by c2 is ambiguous
|
||||
// select c1 as c2, c2 from t order by c2 + 1 is ambiguous
|
||||
type orderByVisitor struct {
|
||||
expression.BaseVisitor
|
||||
selectList *plans.SelectList
|
||||
rootIdent *expression.Ident
|
||||
}
|
||||
|
||||
// TODO: use vistor to refactor all and combine following plan check.
|
||||
names := expression.MentionedColumns(v.Expr)
|
||||
for _, name := range names {
|
||||
// try to find in select list
|
||||
// TODO: mysql has confused result for this, see #555.
|
||||
// now we use select list then order by, later we should make it easier.
|
||||
if field.ContainFieldName(name, selectList.ResultFields, field.CheckFieldFlag) {
|
||||
continue
|
||||
}
|
||||
func (v *orderByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) {
|
||||
// Order by ambiguous rule:
|
||||
// select c1 as a, c2 as a from t order by a is ambiguous
|
||||
// select c1 as a, c2 as a from t order by a + 1 is ambiguous
|
||||
// select c1 as c2, c2 from t order by c2 is ambiguous
|
||||
// select c1 as c2, c2 from t order by c2 + 1 is not ambiguous
|
||||
|
||||
if !selectList.CloneHiddenField(name, tableFields) {
|
||||
return errors.Errorf("Unknown column '%s' in 'order clause'", name)
|
||||
}
|
||||
}
|
||||
// Order by identifier reference check
|
||||
// select c1 as c2 from t order by c2, c2 in order by references c1 in select field.
|
||||
// select c1 as c2 from t order by c2 + 1, c2 in order by c2 + 1 references t.c2 in from table.
|
||||
// select c1 as c2 from t order by sum(c2), c2 in order by sum(c2) references t.c2 in from table.
|
||||
// select c1 as c2 from t order by sum(1) + c2, c2 in order by sum(1) + c2 references t.c2 in from table.
|
||||
// select c1 as a from t order by a, a references c1 in select field.
|
||||
// select c1 as a from t order by sum(a), a in order by sum(a) references c1 in select field.
|
||||
|
||||
// TODO: unify VisitIdent for group by and order by visitor, they are little different.
|
||||
|
||||
var (
|
||||
index int
|
||||
err error
|
||||
)
|
||||
|
||||
if v.rootIdent == i {
|
||||
// The order by is an identifier, we must check it first.
|
||||
index, err = checkIdentAmbiguous(i, v.selectList, orderByClause)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
if index >= 0 {
|
||||
// identifier references a select field. use it directly.
|
||||
// e,g. select c1 as c2 from t order by c2, here c2 references c1.
|
||||
i.ReferScope = expression.IdentReferSelectList
|
||||
i.ReferIndex = index
|
||||
return i, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
// find this identifier in FROM.
|
||||
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag)
|
||||
if len(idx) > 0 {
|
||||
i.ReferScope = expression.IdentReferFromTable
|
||||
i.ReferIndex = idx[0]
|
||||
return i, nil
|
||||
}
|
||||
|
||||
if v.rootIdent != i {
|
||||
// This identifier is the part of the order by, check ambiguous here.
|
||||
index, err = checkIdentAmbiguous(i, v.selectList, orderByClause)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
}
|
||||
|
||||
// try to find in select list, we have got index using checkIdentAmbiguous before.
|
||||
if index >= 0 {
|
||||
// find in select list
|
||||
i.ReferScope = expression.IdentReferSelectList
|
||||
i.ReferIndex = index
|
||||
return i, nil
|
||||
}
|
||||
|
||||
// TODO: check in out query
|
||||
return i, errors.Errorf("Unknown column '%s' in 'order clause'", i)
|
||||
}
|
||||
|
||||
func (v *orderByVisitor) VisitPosition(p *expression.Position) (expression.Expression, error) {
|
||||
n := p.N
|
||||
if p.N <= v.selectList.HiddenFieldOffset {
|
||||
// this position expression is not in hidden field, no need to check
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// we have added order by to hidden field before, now visit it here.
|
||||
expr := v.selectList.Fields[n-1].Expr
|
||||
e, err := expr.Accept(v)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
v.selectList.Fields[n-1].Expr = e
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Plan get SrcPlan/OrderByDefaultPlan.
|
||||
@ -120,45 +184,28 @@ func (r *OrderByRset) Plan(ctx context.Context) (plan.Plan, error) {
|
||||
ascs []bool
|
||||
)
|
||||
|
||||
fields := r.Src.GetFields()
|
||||
visitor := &orderByVisitor{}
|
||||
visitor.BaseVisitor.V = visitor
|
||||
visitor.selectList = r.SelectList
|
||||
|
||||
for i := range r.By {
|
||||
e := r.By[i].Expr
|
||||
if v, ok := e.(expression.Value); ok {
|
||||
var (
|
||||
position int
|
||||
isPosition = true
|
||||
)
|
||||
|
||||
switch u := v.Val.(type) {
|
||||
case int64:
|
||||
position = int(u)
|
||||
case uint64:
|
||||
position = int(u)
|
||||
default:
|
||||
isPosition = false
|
||||
// only const value
|
||||
}
|
||||
pos, err := castPosition(e, r.SelectList, orderByClause)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
if isPosition {
|
||||
if position < 1 || position > len(fields) {
|
||||
return nil, errors.Errorf("Unknown column '%d' in 'order clause'", position)
|
||||
}
|
||||
|
||||
// use Position expression for the associated field.
|
||||
r.By[i].Expr = &expression.Position{N: position}
|
||||
}
|
||||
if pos != nil {
|
||||
// use Position expression for the associated field.
|
||||
r.By[i].Expr = pos
|
||||
} else {
|
||||
// Don't check ambiguous here, only check field exists or not.
|
||||
// TODO: use visitor to refactor.
|
||||
colNames := expression.MentionedColumns(e)
|
||||
for _, name := range colNames {
|
||||
if idx := field.GetResultFieldIndex(name, r.SelectList.ResultFields, field.DefaultFieldFlag); len(idx) == 0 {
|
||||
// find in from
|
||||
if idx = field.GetResultFieldIndex(name, r.SelectList.FromFields, field.DefaultFieldFlag); len(idx) == 0 {
|
||||
return nil, errors.Errorf("unknown field %s", name)
|
||||
}
|
||||
}
|
||||
visitor.rootIdent = castIdent(e)
|
||||
e, err = e.Accept(visitor)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
r.By[i].Expr = e
|
||||
}
|
||||
|
||||
by = append(by, r.By[i].Expr)
|
||||
|
||||
@ -47,64 +47,6 @@ func (s *testOrderByRsetSuite) SetUpSuite(c *C) {
|
||||
s.r = &OrderByRset{Src: tblPlan, SelectList: selectList}
|
||||
}
|
||||
|
||||
func (s *testOrderByRsetSuite) TestOrderByRsetCheckAndUpdateSelectList(c *C) {
|
||||
resultFields := s.r.Src.GetFields()
|
||||
|
||||
fields := make([]*field.Field, len(resultFields))
|
||||
for i, resultField := range resultFields {
|
||||
name := resultField.Name
|
||||
fields[i] = &field.Field{Expr: &expression.Ident{CIStr: model.NewCIStr(name)}}
|
||||
}
|
||||
|
||||
selectList := &plans.SelectList{
|
||||
HiddenFieldOffset: len(resultFields),
|
||||
ResultFields: resultFields,
|
||||
Fields: fields,
|
||||
}
|
||||
|
||||
expr := &expression.Ident{CIStr: model.NewCIStr("id")}
|
||||
orderByItem := OrderByItem{Expr: expr, Asc: true}
|
||||
by := []OrderByItem{orderByItem}
|
||||
r := &OrderByRset{By: by, SelectList: selectList}
|
||||
|
||||
// `select id, name from t order by id`
|
||||
err := r.CheckAndUpdateSelectList(selectList, resultFields)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
// `select id, name as id from t order by id`
|
||||
selectList.Fields[1].AsName = "id"
|
||||
selectList.ResultFields[1].Name = "id"
|
||||
|
||||
err = r.CheckAndUpdateSelectList(selectList, resultFields)
|
||||
c.Assert(err, NotNil)
|
||||
|
||||
// `select id, name from t order by count(1) > 1`
|
||||
aggExpr, err := expression.NewCall("count", []expression.Expression{expression.Value{Val: 1}}, false)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
r.By[0].Expr = aggExpr
|
||||
|
||||
err = r.CheckAndUpdateSelectList(selectList, resultFields)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
// `select id, name from t order by count(xxx) > 1`
|
||||
aggExpr, err = expression.NewCall("count", []expression.Expression{&expression.Ident{CIStr: model.NewCIStr("xxx")}}, false)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
r.By[0].Expr = aggExpr
|
||||
|
||||
err = r.CheckAndUpdateSelectList(selectList, resultFields)
|
||||
c.Assert(err, NotNil)
|
||||
|
||||
// `select id, name from t order by xxx`
|
||||
r.By[0].Expr = &expression.Ident{CIStr: model.NewCIStr("xxx")}
|
||||
selectList.Fields[1].AsName = "name"
|
||||
selectList.ResultFields[1].Name = "name"
|
||||
|
||||
err = r.CheckAndUpdateSelectList(selectList, resultFields)
|
||||
c.Assert(err, NotNil)
|
||||
}
|
||||
|
||||
func (s *testOrderByRsetSuite) TestOrderByRsetPlan(c *C) {
|
||||
// `select id, name from t`
|
||||
_, err := s.r.Plan(nil)
|
||||
@ -145,12 +87,6 @@ func (s *testOrderByRsetSuite) TestOrderByRsetPlan(c *C) {
|
||||
_, err = s.r.Plan(nil)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
// `select id, name from t order by xxx`
|
||||
s.r.By[0].Expr = &expression.Ident{CIStr: model.NewCIStr("xxx")}
|
||||
|
||||
_, err = s.r.Plan(nil)
|
||||
c.Assert(err, NotNil)
|
||||
|
||||
// check src plan is NullPlan
|
||||
s.r.Src = &plans.NullPlan{Fields: s.r.SelectList.ResultFields}
|
||||
|
||||
|
||||
@ -162,7 +162,6 @@ func (r *WhereRset) planStatic(ctx context.Context, e expression.Expression) (pl
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if val == nil {
|
||||
// like `select * from t where null`.
|
||||
return &plans.NullPlan{Fields: r.Src.GetFields()}, nil
|
||||
@ -172,7 +171,6 @@ func (r *WhereRset) planStatic(ctx context.Context, e expression.Expression) (pl
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
// like `select * from t where 0`.
|
||||
return &plans.NullPlan{Fields: r.Src.GetFields()}, nil
|
||||
@ -181,8 +179,25 @@ func (r *WhereRset) planStatic(ctx context.Context, e expression.Expression) (pl
|
||||
return &plans.FilterDefaultPlan{Plan: r.Src, Expr: e}, nil
|
||||
}
|
||||
|
||||
func (r *WhereRset) updateWhereFieldsRefer() error {
|
||||
visitor := newFromIdentVisitor(r.Src.GetFields(), whereClause)
|
||||
|
||||
e, err := r.Expr.Accept(visitor)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
r.Expr = e
|
||||
return nil
|
||||
}
|
||||
|
||||
// Plan gets NullPlan/FilterDefaultPlan.
|
||||
func (r *WhereRset) Plan(ctx context.Context) (plan.Plan, error) {
|
||||
// Update where fields refer expr by using fromIdentVisitor.
|
||||
if err := r.updateWhereFieldsRefer(); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
expr := r.Expr.Clone()
|
||||
if expr.IsStatic() {
|
||||
// IsStaic means we have a const value for where condition, and we don't need any index.
|
||||
|
||||
@ -31,7 +31,7 @@ import (
|
||||
"github.com/pingcap/tidb/context"
|
||||
"github.com/pingcap/tidb/field"
|
||||
"github.com/pingcap/tidb/kv"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/rset"
|
||||
"github.com/pingcap/tidb/sessionctx"
|
||||
"github.com/pingcap/tidb/sessionctx/db"
|
||||
|
||||
@ -15,7 +15,7 @@ package variable
|
||||
|
||||
import (
|
||||
"github.com/pingcap/tidb/context"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
)
|
||||
|
||||
// SessionVars is to handle user-defined or global variables in current session.
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user