diff --git a/executor/executor_simple_test.go b/executor/executor_simple_test.go index 4ceae923c9..44300ba334 100644 --- a/executor/executor_simple_test.go +++ b/executor/executor_simple_test.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/tidb/meta" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/plan/statistics" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" @@ -30,6 +31,7 @@ import ( ) func (s *testSuite) TestCharsetDatabase(c *C) { + plan.UseNewPlanner = true defer testleak.AfterTest(c)() tk := testkit.NewTestKit(c, s.store) testSQL := `create database if not exists cd_test_utf8 CHARACTER SET utf8 COLLATE utf8_bin;` @@ -47,9 +49,11 @@ func (s *testSuite) TestCharsetDatabase(c *C) { tk.MustExec(testSQL) tk.MustQuery(`select @@character_set_database;`).Check(testkit.Rows("latin1")) tk.MustQuery(`select @@collation_database;`).Check(testkit.Rows("latin1_swedish_ci")) + plan.UseNewPlanner = false } func (s *testSuite) TestSetVar(c *C) { + plan.UseNewPlanner = true defer testleak.AfterTest(c)() tk := testkit.NewTestKit(c, s.store) testSQL := "SET @a = 1;" @@ -106,6 +110,7 @@ func (s *testSuite) TestSetVar(c *C) { testSQL = "SET @@global.autocommit=1, @issue998b=6;" tk.MustExec(testSQL) tk.MustQuery(`select @issue998b, @@global.autocommit;`).Check(testkit.Rows("6 1")) + plan.UseNewPlanner = false } func (s *testSuite) TestSetCharset(c *C) { diff --git a/executor/executor_test.go b/executor/executor_test.go index 48a49fc012..34e08d6309 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -127,6 +127,7 @@ func (s *testSuite) TestAdmin(c *C) { } func (s *testSuite) TestPrepared(c *C) { + plan.UseNewPlanner = true defer testleak.AfterTest(c)() tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -182,6 +183,7 @@ func (s *testSuite) TestPrepared(c *C) { exec.Fields() exec.Next() exec.Close() + plan.UseNewPlanner = false } func (s *testSuite) fillData(tk *testkit.TestKit, table string) { @@ -591,6 +593,7 @@ func (s *testSuite) TestSelectWithoutFrom(c *C) { } func (s *testSuite) TestSelectLimit(c *C) { + plan.UseNewPlanner = true defer testleak.AfterTest(c)() tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -629,6 +632,7 @@ func (s *testSuite) TestSelectLimit(c *C) { _, err := tk.Exec("select * from select_limit limit 18446744073709551616 offset 3;") c.Assert(err, NotNil) tk.MustExec("rollback") + plan.UseNewPlanner = false } func (s *testSuite) TestSelectOrderBy(c *C) { @@ -732,6 +736,7 @@ func (s *testSuite) TestSelectDistinct(c *C) { } func (s *testSuite) TestSelectErrorRow(c *C) { + plan.UseNewPlanner = true defer testleak.AfterTest(c)() tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -762,6 +767,7 @@ func (s *testSuite) TestSelectErrorRow(c *C) { c.Assert(err, NotNil) tk.MustExec("commit") + plan.UseNewPlanner = false } func (s *testSuite) TestUpdate(c *C) { @@ -1220,6 +1226,7 @@ func (s *testSuite) TestIndexReverseOrder(c *C) { } func (s *testSuite) TestTableReverseOrder(c *C) { + plan.UseNewPlanner = true defer testleak.AfterTest(c)() tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -1230,6 +1237,7 @@ func (s *testSuite) TestTableReverseOrder(c *C) { result.Check(testkit.Rows("9", "8", "7", "6", "5", "4", "3", "2", "1")) result = tk.MustQuery("select a from t where a <3 or (a >=6 and a < 8) order by a desc") result.Check(testkit.Rows("7", "6", "2", "1")) + plan.UseNewPlanner = false } func (s *testSuite) TestInSubquery(c *C) { diff --git a/executor/executor_write.go b/executor/executor_write.go index 5f6d31dcdf..1b5277bf55 100644 --- a/executor/executor_write.go +++ b/executor/executor_write.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" @@ -650,7 +651,7 @@ func (e *InsertValues) getRow(cols []*table.Column, list []ast.ExprNode, default func (e *InsertValues) getRowsSelect(cols []*table.Column) ([][]types.Datum, error) { // process `insert|replace into ... select ... from ...` - if len(e.SelectExec.Fields()) != len(cols) { + if (!plan.UseNewPlanner && len(e.SelectExec.Fields()) != len(cols)) || (plan.UseNewPlanner && len(e.SelectExec.Schema()) != len(cols)) { return nil, errors.Errorf("Column count %d doesn't match value count %d", len(cols), len(e.SelectExec.Fields())) } var rows [][]types.Datum diff --git a/expression/expression.go b/expression/expression.go index 40ac085a53..c703c504b2 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -231,9 +231,10 @@ func NewFunction(funcName string, retType *types.FieldType, args ...Expression) log.Errorf("Function %s is not implemented.", funcName) return nil } - + funcArgs := make([]Expression, len(args)) + copy(funcArgs, args) return &ScalarFunction{ - Args: args, + Args: funcArgs, FuncName: model.NewCIStr(funcName), RetType: retType, Function: f.F} diff --git a/plan/expression_rewriter.go b/plan/expression_rewriter.go index fa86408129..02c838da93 100644 --- a/plan/expression_rewriter.go +++ b/plan/expression_rewriter.go @@ -1,6 +1,8 @@ package plan import ( + "strings" + "github.com/juju/errors" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/context" @@ -9,6 +11,7 @@ import ( "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/types" ) @@ -46,8 +49,8 @@ func getRowLen(e expression.Expression) int { if f, ok := e.(*expression.ScalarFunction); ok && f.FuncName.L == ast.RowFunc { return len(f.Args) } - if f, ok := e.(*expression.Constant); ok && f.RetType.Tp == types.KindRow { - return len(f.Value.GetRow()) + if c, ok := e.(*expression.Constant); ok && c.Value.Kind() == types.KindRow { + return len(c.Value.GetRow()) } return 1 } @@ -231,7 +234,7 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { for _, col := range np.GetSchema() { args = append(args, col.DeepCopy()) } - rexpr = expression.NewFunction(ast.RowFunc, types.NewFieldType(types.KindRow), args...) + rexpr = expression.NewFunction(ast.RowFunc, nil, args...) } // a in (subq) will be rewrited as a = any(subq). // a not in (subq) will be rewrited as a != all(subq). @@ -263,7 +266,7 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { newCols = append(newCols, col.DeepCopy()) } er.ctxStack = append(er.ctxStack, - expression.NewFunction(ast.RowFunc, types.NewFieldType(types.KindRow), newCols...)) + expression.NewFunction(ast.RowFunc, nil, newCols...)) } else { er.ctxStack = append(er.ctxStack, np.GetSchema()[0]) } @@ -287,7 +290,7 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { RetType: np.GetSchema()[i].GetType()}) } er.ctxStack = append(er.ctxStack, - expression.NewFunction(ast.RowFunc, types.NewFieldType(types.KindRow), newCols...)) + expression.NewFunction(ast.RowFunc, nil, newCols...)) } else { er.ctxStack = append(er.ctxStack, np.GetSchema()[0]) } @@ -312,8 +315,9 @@ func (er *expressionRewriter) Leave(inNode ast.Node) (retNode ast.Node, ok bool) rows = append(rows, er.ctxStack[i]) } er.ctxStack = er.ctxStack[:stkLen-length] - er.ctxStack = append(er.ctxStack, expression.NewFunction(ast.RowFunc, types.NewFieldType(types.KindRow), rows...)) - + er.ctxStack = append(er.ctxStack, expression.NewFunction(ast.RowFunc, nil, rows...)) + case *ast.VariableExpr: + return inNode, er.rewriteVariable(v) case *ast.FuncCallExpr: er.funcCallToScalarFunc(v) case *ast.PositionExpr: @@ -326,7 +330,10 @@ func (er *expressionRewriter) Leave(inNode ast.Node) (retNode ast.Node, ok bool) er.toColumn(v) case *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.WhenClause, *ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr: case *ast.ValueExpr: - value := &expression.Constant{Value: v.Datum, RetType: types.NewFieldType(v.Datum.Kind())} + value := &expression.Constant{Value: v.Datum, RetType: v.Type} + er.ctxStack = append(er.ctxStack, value) + case *ast.ParamMarkerExpr: + value := &expression.Constant{Value: v.Datum, RetType: v.Type} er.ctxStack = append(er.ctxStack, value) case *ast.IsNullExpr: if getRowLen(er.ctxStack[stkLen-1]) != 1 { @@ -356,7 +363,7 @@ func (er *expressionRewriter) Leave(inNode ast.Node) (retNode ast.Node, ok bool) return retNode, false } default: - function = expression.NewFunction(opcode.Ops[v.Op], v.Type, er.ctxStack[stkLen-2], er.ctxStack[stkLen-1]) + function = expression.NewFunction(opcode.Ops[v.Op], v.Type, er.ctxStack[stkLen-2:]...) } er.ctxStack = er.ctxStack[:stkLen-2] er.ctxStack = append(er.ctxStack, function) @@ -387,6 +394,83 @@ func (er *expressionRewriter) Leave(inNode ast.Node) (retNode ast.Node, ok bool) return inNode, true } +func (er *expressionRewriter) rewriteVariable(v *ast.VariableExpr) bool { + stkLen := len(er.ctxStack) + name := strings.ToLower(v.Name) + sessionVars := variable.GetSessionVars(er.b.ctx) + globalVars := variable.GetGlobalVarAccessor(er.b.ctx) + if !v.IsSystem { + var d types.Datum + var err error + if v.Value != nil { + d, err = er.ctxStack[stkLen-1].Eval(nil, er.b.ctx) + if err != nil { + er.err = errors.Trace(err) + return false + } + er.ctxStack = er.ctxStack[:stkLen-1] + } + if !d.IsNull() { + + strVal, err := d.ToString() + if err != nil { + er.err = errors.Trace(err) + return false + } + sessionVars.Users[name] = strings.ToLower(strVal) + er.ctxStack = append(er.ctxStack, &expression.Constant{Value: d, RetType: types.NewFieldType(mysql.TypeString)}) + } else if value, ok := sessionVars.Users[name]; ok { + er.ctxStack = append(er.ctxStack, &expression.Constant{Value: types.NewDatum(value), RetType: types.NewFieldType(mysql.TypeString)}) + } else { + // select null user vars is permitted. + er.ctxStack = append(er.ctxStack, &expression.Constant{RetType: types.NewFieldType(mysql.TypeNull)}) + } + return true + } + + sysVar, ok := variable.SysVars[name] + if !ok { + // select null sys vars is not permitted + er.err = variable.UnknownSystemVar.Gen("Unknown system variable '%s'", name) + return false + } + if sysVar.Scope == variable.ScopeNone { + er.ctxStack = append(er.ctxStack, &expression.Constant{Value: types.NewDatum(sysVar.Value), RetType: types.NewFieldType(mysql.TypeString)}) + return true + } + + if !v.IsGlobal { + d := sessionVars.GetSystemVar(name) + if d.IsNull() { + if sysVar.Scope&variable.ScopeGlobal == 0 { + d.SetString(sysVar.Value) + } else { + // Get global system variable and fill it in session. + globalVal, err := globalVars.GetGlobalSysVar(er.b.ctx, name) + if err != nil { + er.err = errors.Trace(err) + return false + } + d.SetString(globalVal) + err = sessionVars.SetSystemVar(name, d) + if err != nil { + er.err = errors.Trace(err) + return false + } + } + } + er.ctxStack = append(er.ctxStack, &expression.Constant{Value: d, RetType: types.NewFieldType(mysql.TypeString)}) + return true + } + value, err := globalVars.GetGlobalSysVar(er.b.ctx, name) + if err != nil { + er.err = errors.Trace(err) + return false + } + er.ctxStack = append(er.ctxStack, &expression.Constant{Value: types.NewDatum(value), RetType: types.NewFieldType(mysql.TypeString)}) + return true +} + func (er *expressionRewriter) notToScalarFunc(b bool, op string, tp *types.FieldType, args ...expression.Expression) *expression.ScalarFunction { opFunc := expression.NewFunction(op, tp, args...) diff --git a/plan/newplanbuilder.go b/plan/newplanbuilder.go index ac16721efa..e208dcd5f2 100644 --- a/plan/newplanbuilder.go +++ b/plan/newplanbuilder.go @@ -325,16 +325,18 @@ func (b *planBuilder) buildNewUnion(union *ast.UnionStmt) Plan { } u.SetSchema(firstSchema) + var p Plan + p = u if union.Distinct { - return b.buildNewDistinct(u) + p = b.buildNewDistinct(u) } if union.OrderBy != nil { - return b.buildNewSort(u, union.OrderBy.Items, nil) + p = b.buildNewSort(p, union.OrderBy.Items, nil) } if union.Limit != nil { - return b.buildNewLimit(u, union.Limit) + p = b.buildNewLimit(p, union.Limit) } - return u + return p } // ByItems wraps a "by" item. diff --git a/plan/planbuilder.go b/plan/planbuilder.go index dd4e6799f8..16993a35ab 100644 --- a/plan/planbuilder.go +++ b/plan/planbuilder.go @@ -104,6 +104,9 @@ func (b *planBuilder) build(node ast.Node) Plan { } return b.buildSelect(x) case *ast.UnionStmt: + if UseNewPlanner { + return b.buildNewUnion(x) + } return b.buildUnion(x) case *ast.UpdateStmt: return b.buildUpdate(x)