support set variable, prepare and insert stmt. (#1359)
This commit is contained in:
@ -21,6 +21,7 @@ import (
|
||||
"github.com/pingcap/tidb/meta"
|
||||
"github.com/pingcap/tidb/model"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/plan"
|
||||
"github.com/pingcap/tidb/plan/statistics"
|
||||
"github.com/pingcap/tidb/sessionctx"
|
||||
"github.com/pingcap/tidb/sessionctx/variable"
|
||||
@ -30,6 +31,7 @@ import (
|
||||
)
|
||||
|
||||
func (s *testSuite) TestCharsetDatabase(c *C) {
|
||||
plan.UseNewPlanner = true
|
||||
defer testleak.AfterTest(c)()
|
||||
tk := testkit.NewTestKit(c, s.store)
|
||||
testSQL := `create database if not exists cd_test_utf8 CHARACTER SET utf8 COLLATE utf8_bin;`
|
||||
@ -47,9 +49,11 @@ func (s *testSuite) TestCharsetDatabase(c *C) {
|
||||
tk.MustExec(testSQL)
|
||||
tk.MustQuery(`select @@character_set_database;`).Check(testkit.Rows("latin1"))
|
||||
tk.MustQuery(`select @@collation_database;`).Check(testkit.Rows("latin1_swedish_ci"))
|
||||
plan.UseNewPlanner = false
|
||||
}
|
||||
|
||||
func (s *testSuite) TestSetVar(c *C) {
|
||||
plan.UseNewPlanner = true
|
||||
defer testleak.AfterTest(c)()
|
||||
tk := testkit.NewTestKit(c, s.store)
|
||||
testSQL := "SET @a = 1;"
|
||||
@ -106,6 +110,7 @@ func (s *testSuite) TestSetVar(c *C) {
|
||||
testSQL = "SET @@global.autocommit=1, @issue998b=6;"
|
||||
tk.MustExec(testSQL)
|
||||
tk.MustQuery(`select @issue998b, @@global.autocommit;`).Check(testkit.Rows("6 1"))
|
||||
plan.UseNewPlanner = false
|
||||
}
|
||||
|
||||
func (s *testSuite) TestSetCharset(c *C) {
|
||||
|
||||
@ -127,6 +127,7 @@ func (s *testSuite) TestAdmin(c *C) {
|
||||
}
|
||||
|
||||
func (s *testSuite) TestPrepared(c *C) {
|
||||
plan.UseNewPlanner = true
|
||||
defer testleak.AfterTest(c)()
|
||||
tk := testkit.NewTestKit(c, s.store)
|
||||
tk.MustExec("use test")
|
||||
@ -182,6 +183,7 @@ func (s *testSuite) TestPrepared(c *C) {
|
||||
exec.Fields()
|
||||
exec.Next()
|
||||
exec.Close()
|
||||
plan.UseNewPlanner = false
|
||||
}
|
||||
|
||||
func (s *testSuite) fillData(tk *testkit.TestKit, table string) {
|
||||
@ -591,6 +593,7 @@ func (s *testSuite) TestSelectWithoutFrom(c *C) {
|
||||
}
|
||||
|
||||
func (s *testSuite) TestSelectLimit(c *C) {
|
||||
plan.UseNewPlanner = true
|
||||
defer testleak.AfterTest(c)()
|
||||
tk := testkit.NewTestKit(c, s.store)
|
||||
tk.MustExec("use test")
|
||||
@ -629,6 +632,7 @@ func (s *testSuite) TestSelectLimit(c *C) {
|
||||
_, err := tk.Exec("select * from select_limit limit 18446744073709551616 offset 3;")
|
||||
c.Assert(err, NotNil)
|
||||
tk.MustExec("rollback")
|
||||
plan.UseNewPlanner = false
|
||||
}
|
||||
|
||||
func (s *testSuite) TestSelectOrderBy(c *C) {
|
||||
@ -732,6 +736,7 @@ func (s *testSuite) TestSelectDistinct(c *C) {
|
||||
}
|
||||
|
||||
func (s *testSuite) TestSelectErrorRow(c *C) {
|
||||
plan.UseNewPlanner = true
|
||||
defer testleak.AfterTest(c)()
|
||||
tk := testkit.NewTestKit(c, s.store)
|
||||
tk.MustExec("use test")
|
||||
@ -762,6 +767,7 @@ func (s *testSuite) TestSelectErrorRow(c *C) {
|
||||
c.Assert(err, NotNil)
|
||||
|
||||
tk.MustExec("commit")
|
||||
plan.UseNewPlanner = false
|
||||
}
|
||||
|
||||
func (s *testSuite) TestUpdate(c *C) {
|
||||
@ -1220,6 +1226,7 @@ func (s *testSuite) TestIndexReverseOrder(c *C) {
|
||||
}
|
||||
|
||||
func (s *testSuite) TestTableReverseOrder(c *C) {
|
||||
plan.UseNewPlanner = true
|
||||
defer testleak.AfterTest(c)()
|
||||
tk := testkit.NewTestKit(c, s.store)
|
||||
tk.MustExec("use test")
|
||||
@ -1230,6 +1237,7 @@ func (s *testSuite) TestTableReverseOrder(c *C) {
|
||||
result.Check(testkit.Rows("9", "8", "7", "6", "5", "4", "3", "2", "1"))
|
||||
result = tk.MustQuery("select a from t where a <3 or (a >=6 and a < 8) order by a desc")
|
||||
result.Check(testkit.Rows("7", "6", "2", "1"))
|
||||
plan.UseNewPlanner = false
|
||||
}
|
||||
|
||||
func (s *testSuite) TestInSubquery(c *C) {
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/kv"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/plan"
|
||||
"github.com/pingcap/tidb/sessionctx"
|
||||
"github.com/pingcap/tidb/sessionctx/variable"
|
||||
"github.com/pingcap/tidb/table"
|
||||
@ -650,7 +651,7 @@ func (e *InsertValues) getRow(cols []*table.Column, list []ast.ExprNode, default
|
||||
|
||||
func (e *InsertValues) getRowsSelect(cols []*table.Column) ([][]types.Datum, error) {
|
||||
// process `insert|replace into ... select ... from ...`
|
||||
if len(e.SelectExec.Fields()) != len(cols) {
|
||||
if (!plan.UseNewPlanner && len(e.SelectExec.Fields()) != len(cols)) || (plan.UseNewPlanner && len(e.SelectExec.Schema()) != len(cols)) {
|
||||
return nil, errors.Errorf("Column count %d doesn't match value count %d", len(cols), len(e.SelectExec.Fields()))
|
||||
}
|
||||
var rows [][]types.Datum
|
||||
|
||||
@ -231,9 +231,10 @@ func NewFunction(funcName string, retType *types.FieldType, args ...Expression)
|
||||
log.Errorf("Function %s is not implemented.", funcName)
|
||||
return nil
|
||||
}
|
||||
|
||||
funcArgs := make([]Expression, len(args))
|
||||
copy(funcArgs, args)
|
||||
return &ScalarFunction{
|
||||
Args: args,
|
||||
Args: funcArgs,
|
||||
FuncName: model.NewCIStr(funcName),
|
||||
RetType: retType,
|
||||
Function: f.F}
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
package plan
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/ast"
|
||||
"github.com/pingcap/tidb/context"
|
||||
@ -9,6 +11,7 @@ import (
|
||||
"github.com/pingcap/tidb/infoschema"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/parser/opcode"
|
||||
"github.com/pingcap/tidb/sessionctx/variable"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
@ -46,8 +49,8 @@ func getRowLen(e expression.Expression) int {
|
||||
if f, ok := e.(*expression.ScalarFunction); ok && f.FuncName.L == ast.RowFunc {
|
||||
return len(f.Args)
|
||||
}
|
||||
if f, ok := e.(*expression.Constant); ok && f.RetType.Tp == types.KindRow {
|
||||
return len(f.Value.GetRow())
|
||||
if c, ok := e.(*expression.Constant); ok && c.Value.Kind() == types.KindRow {
|
||||
return len(c.Value.GetRow())
|
||||
}
|
||||
return 1
|
||||
}
|
||||
@ -231,7 +234,7 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
|
||||
for _, col := range np.GetSchema() {
|
||||
args = append(args, col.DeepCopy())
|
||||
}
|
||||
rexpr = expression.NewFunction(ast.RowFunc, types.NewFieldType(types.KindRow), args...)
|
||||
rexpr = expression.NewFunction(ast.RowFunc, nil, args...)
|
||||
}
|
||||
// a in (subq) will be rewrited as a = any(subq).
|
||||
// a not in (subq) will be rewrited as a != all(subq).
|
||||
@ -263,7 +266,7 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
|
||||
newCols = append(newCols, col.DeepCopy())
|
||||
}
|
||||
er.ctxStack = append(er.ctxStack,
|
||||
expression.NewFunction(ast.RowFunc, types.NewFieldType(types.KindRow), newCols...))
|
||||
expression.NewFunction(ast.RowFunc, nil, newCols...))
|
||||
} else {
|
||||
er.ctxStack = append(er.ctxStack, np.GetSchema()[0])
|
||||
}
|
||||
@ -287,7 +290,7 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
|
||||
RetType: np.GetSchema()[i].GetType()})
|
||||
}
|
||||
er.ctxStack = append(er.ctxStack,
|
||||
expression.NewFunction(ast.RowFunc, types.NewFieldType(types.KindRow), newCols...))
|
||||
expression.NewFunction(ast.RowFunc, nil, newCols...))
|
||||
} else {
|
||||
er.ctxStack = append(er.ctxStack, np.GetSchema()[0])
|
||||
}
|
||||
@ -312,8 +315,9 @@ func (er *expressionRewriter) Leave(inNode ast.Node) (retNode ast.Node, ok bool)
|
||||
rows = append(rows, er.ctxStack[i])
|
||||
}
|
||||
er.ctxStack = er.ctxStack[:stkLen-length]
|
||||
er.ctxStack = append(er.ctxStack, expression.NewFunction(ast.RowFunc, types.NewFieldType(types.KindRow), rows...))
|
||||
|
||||
er.ctxStack = append(er.ctxStack, expression.NewFunction(ast.RowFunc, nil, rows...))
|
||||
case *ast.VariableExpr:
|
||||
return inNode, er.rewriteVariable(v)
|
||||
case *ast.FuncCallExpr:
|
||||
er.funcCallToScalarFunc(v)
|
||||
case *ast.PositionExpr:
|
||||
@ -326,7 +330,10 @@ func (er *expressionRewriter) Leave(inNode ast.Node) (retNode ast.Node, ok bool)
|
||||
er.toColumn(v)
|
||||
case *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.WhenClause, *ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr:
|
||||
case *ast.ValueExpr:
|
||||
value := &expression.Constant{Value: v.Datum, RetType: types.NewFieldType(v.Datum.Kind())}
|
||||
value := &expression.Constant{Value: v.Datum, RetType: v.Type}
|
||||
er.ctxStack = append(er.ctxStack, value)
|
||||
case *ast.ParamMarkerExpr:
|
||||
value := &expression.Constant{Value: v.Datum, RetType: v.Type}
|
||||
er.ctxStack = append(er.ctxStack, value)
|
||||
case *ast.IsNullExpr:
|
||||
if getRowLen(er.ctxStack[stkLen-1]) != 1 {
|
||||
@ -356,7 +363,7 @@ func (er *expressionRewriter) Leave(inNode ast.Node) (retNode ast.Node, ok bool)
|
||||
return retNode, false
|
||||
}
|
||||
default:
|
||||
function = expression.NewFunction(opcode.Ops[v.Op], v.Type, er.ctxStack[stkLen-2], er.ctxStack[stkLen-1])
|
||||
function = expression.NewFunction(opcode.Ops[v.Op], v.Type, er.ctxStack[stkLen-2:]...)
|
||||
}
|
||||
er.ctxStack = er.ctxStack[:stkLen-2]
|
||||
er.ctxStack = append(er.ctxStack, function)
|
||||
@ -387,6 +394,83 @@ func (er *expressionRewriter) Leave(inNode ast.Node) (retNode ast.Node, ok bool)
|
||||
return inNode, true
|
||||
}
|
||||
|
||||
func (er *expressionRewriter) rewriteVariable(v *ast.VariableExpr) bool {
|
||||
stkLen := len(er.ctxStack)
|
||||
name := strings.ToLower(v.Name)
|
||||
sessionVars := variable.GetSessionVars(er.b.ctx)
|
||||
globalVars := variable.GetGlobalVarAccessor(er.b.ctx)
|
||||
if !v.IsSystem {
|
||||
var d types.Datum
|
||||
var err error
|
||||
if v.Value != nil {
|
||||
d, err = er.ctxStack[stkLen-1].Eval(nil, er.b.ctx)
|
||||
if err != nil {
|
||||
er.err = errors.Trace(err)
|
||||
return false
|
||||
}
|
||||
er.ctxStack = er.ctxStack[:stkLen-1]
|
||||
}
|
||||
if !d.IsNull() {
|
||||
|
||||
strVal, err := d.ToString()
|
||||
if err != nil {
|
||||
er.err = errors.Trace(err)
|
||||
return false
|
||||
}
|
||||
sessionVars.Users[name] = strings.ToLower(strVal)
|
||||
er.ctxStack = append(er.ctxStack, &expression.Constant{Value: d, RetType: types.NewFieldType(mysql.TypeString)})
|
||||
} else if value, ok := sessionVars.Users[name]; ok {
|
||||
er.ctxStack = append(er.ctxStack, &expression.Constant{Value: types.NewDatum(value), RetType: types.NewFieldType(mysql.TypeString)})
|
||||
} else {
|
||||
// select null user vars is permitted.
|
||||
er.ctxStack = append(er.ctxStack, &expression.Constant{RetType: types.NewFieldType(mysql.TypeNull)})
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
sysVar, ok := variable.SysVars[name]
|
||||
if !ok {
|
||||
// select null sys vars is not permitted
|
||||
er.err = variable.UnknownSystemVar.Gen("Unknown system variable '%s'", name)
|
||||
return false
|
||||
}
|
||||
if sysVar.Scope == variable.ScopeNone {
|
||||
er.ctxStack = append(er.ctxStack, &expression.Constant{Value: types.NewDatum(sysVar.Value), RetType: types.NewFieldType(mysql.TypeString)})
|
||||
return true
|
||||
}
|
||||
|
||||
if !v.IsGlobal {
|
||||
d := sessionVars.GetSystemVar(name)
|
||||
if d.IsNull() {
|
||||
if sysVar.Scope&variable.ScopeGlobal == 0 {
|
||||
d.SetString(sysVar.Value)
|
||||
} else {
|
||||
// Get global system variable and fill it in session.
|
||||
globalVal, err := globalVars.GetGlobalSysVar(er.b.ctx, name)
|
||||
if err != nil {
|
||||
er.err = errors.Trace(err)
|
||||
return false
|
||||
}
|
||||
d.SetString(globalVal)
|
||||
err = sessionVars.SetSystemVar(name, d)
|
||||
if err != nil {
|
||||
er.err = errors.Trace(err)
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
er.ctxStack = append(er.ctxStack, &expression.Constant{Value: d, RetType: types.NewFieldType(mysql.TypeString)})
|
||||
return true
|
||||
}
|
||||
value, err := globalVars.GetGlobalSysVar(er.b.ctx, name)
|
||||
if err != nil {
|
||||
er.err = errors.Trace(err)
|
||||
return false
|
||||
}
|
||||
er.ctxStack = append(er.ctxStack, &expression.Constant{Value: types.NewDatum(value), RetType: types.NewFieldType(mysql.TypeString)})
|
||||
return true
|
||||
}
|
||||
|
||||
func (er *expressionRewriter) notToScalarFunc(b bool, op string, tp *types.FieldType,
|
||||
args ...expression.Expression) *expression.ScalarFunction {
|
||||
opFunc := expression.NewFunction(op, tp, args...)
|
||||
|
||||
@ -325,16 +325,18 @@ func (b *planBuilder) buildNewUnion(union *ast.UnionStmt) Plan {
|
||||
}
|
||||
|
||||
u.SetSchema(firstSchema)
|
||||
var p Plan
|
||||
p = u
|
||||
if union.Distinct {
|
||||
return b.buildNewDistinct(u)
|
||||
p = b.buildNewDistinct(u)
|
||||
}
|
||||
if union.OrderBy != nil {
|
||||
return b.buildNewSort(u, union.OrderBy.Items, nil)
|
||||
p = b.buildNewSort(p, union.OrderBy.Items, nil)
|
||||
}
|
||||
if union.Limit != nil {
|
||||
return b.buildNewLimit(u, union.Limit)
|
||||
p = b.buildNewLimit(p, union.Limit)
|
||||
}
|
||||
return u
|
||||
return p
|
||||
}
|
||||
|
||||
// ByItems wraps a "by" item.
|
||||
|
||||
@ -104,6 +104,9 @@ func (b *planBuilder) build(node ast.Node) Plan {
|
||||
}
|
||||
return b.buildSelect(x)
|
||||
case *ast.UnionStmt:
|
||||
if UseNewPlanner {
|
||||
return b.buildNewUnion(x)
|
||||
}
|
||||
return b.buildUnion(x)
|
||||
case *ast.UpdateStmt:
|
||||
return b.buildUpdate(x)
|
||||
|
||||
Reference in New Issue
Block a user