// Copyright 2021 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // See the License for the specific language governing permissions and // limitations under the License. package executor_test import ( "context" "fmt" "math/rand" "sort" "github.com/pingcap/check" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/testkit" ) var _ = check.Suite(&CTETestSuite{}) type CTETestSuite struct { store kv.Storage dom *domain.Domain sessionCtx sessionctx.Context session session.Session ctx context.Context } func (test *CTETestSuite) SetUpSuite(c *check.C) { var err error test.store, err = mockstore.NewMockStore() c.Assert(err, check.IsNil) test.dom, err = session.BootstrapSession(test.store) c.Assert(err, check.IsNil) test.sessionCtx = mock.NewContext() test.session, err = session.CreateSession4Test(test.store) c.Assert(err, check.IsNil) test.session.SetConnectionID(0) test.ctx = context.Background() } func (test *CTETestSuite) TearDownSuite(c *check.C) { test.dom.Close() test.store.Close() } func (test *CTETestSuite) TestBasicCTE(c *check.C) { tk := testkit.NewTestKit(c, test.store) tk.MustExec("use test") rows := tk.MustQuery("with recursive cte1 as (" + "select 1 c1 " + "union all " + "select c1 + 1 c1 from cte1 where c1 < 5) " + "select * from cte1") rows.Check(testkit.Rows("1", "2", "3", "4", "5")) // Two seed parts. rows = tk.MustQuery("with recursive cte1 as (" + "select 1 c1 " + "union all " + "select 2 c1 " + "union all " + "select c1 + 1 c1 from cte1 where c1 < 10) " + "select * from cte1 order by c1") rows.Check(testkit.Rows("1", "2", "2", "3", "3", "4", "4", "5", "5", "6", "6", "7", "7", "8", "8", "9", "9", "10", "10")) // Two recursive parts. rows = tk.MustQuery("with recursive cte1 as (" + "select 1 c1 " + "union all " + "select 2 c1 " + "union all " + "select c1 + 1 c1 from cte1 where c1 < 3 " + "union all " + "select c1 + 2 c1 from cte1 where c1 < 5) " + "select * from cte1 order by c1") rows.Check(testkit.Rows("1", "2", "2", "3", "3", "3", "4", "4", "5", "5", "5", "6", "6")) tk.MustExec("drop table if exists t1;") tk.MustExec("create table t1(a int);") tk.MustExec("insert into t1 values(1);") tk.MustExec("insert into t1 values(2);") rows = tk.MustQuery("SELECT * FROM t1 dt WHERE EXISTS(WITH RECURSIVE qn AS (SELECT a*0 AS b UNION ALL SELECT b+1 FROM qn WHERE b=0) SELECT * FROM qn WHERE b=a);") rows.Check(testkit.Rows("1")) rows = tk.MustQuery("SELECT * FROM t1 dt WHERE EXISTS( WITH RECURSIVE qn AS (SELECT a*0 AS b UNION ALL SELECT b+1 FROM qn WHERE b=0 or b = 1) SELECT * FROM qn WHERE b=a );") rows.Check(testkit.Rows("1", "2")) } func (test *CTETestSuite) TestSpillToDisk(c *check.C) { tk := testkit.NewTestKit(c, test.store) tk.MustExec("use test;") c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/testCTEStorageSpill", "return(true)"), check.IsNil) defer func() { c.Assert(failpoint.Disable("github.com/pingcap/tidb/executor/testCTEStorageSpill"), check.IsNil) }() insertStr := "insert into t1 values(0, 0)" for i := 1; i < 5000; i++ { insertStr += fmt.Sprintf(", (%d, %d)", i, i) } tk.MustExec("drop table if exists t1;") tk.MustExec("create table t1(c1 int, c2 int);") tk.MustExec(insertStr) tk.MustExec("set tidb_mem_quota_query = 80000;") rows := tk.MustQuery("with recursive cte1 as ( " + "select c1 from t1 " + "union " + "select c1 + 1 c1 from cte1 where c1 < 5000) " + "select c1 from cte1;") memTracker := tk.Se.GetSessionVars().StmtCtx.MemTracker diskTracker := tk.Se.GetSessionVars().StmtCtx.DiskTracker c.Assert(memTracker.MaxConsumed(), check.Greater, int64(0)) c.Assert(diskTracker.MaxConsumed(), check.Greater, int64(0)) rowNum := 5000 var resRows []string for i := 0; i <= rowNum; i++ { resRows = append(resRows, fmt.Sprintf("%d", i)) } rows.Check(testkit.Rows(resRows...)) // Use duplicated rows to test UNION DISTINCT. tk.MustExec("set tidb_mem_quota_query = 1073741824;") insertStr = "insert into t1 values(0, 0)" vals := make([]int, rowNum) vals[0] = 0 for i := 1; i < rowNum; i++ { v := rand.Intn(100) vals[i] = v insertStr += fmt.Sprintf(", (%d, %d)", v, v) } tk.MustExec("drop table if exists t1;") tk.MustExec("create table t1(c1 int, c2 int);") tk.MustExec(insertStr) tk.MustExec("set tidb_mem_quota_query = 80000;") tk.MustExec("set cte_max_recursion_depth = 500000;") rows = tk.MustQuery("with recursive cte1 as ( " + "select c1 from t1 " + "union " + "select c1 + 1 c1 from cte1 where c1 < 5000) " + "select c1 from cte1 order by c1;") memTracker = tk.Se.GetSessionVars().StmtCtx.MemTracker diskTracker = tk.Se.GetSessionVars().StmtCtx.DiskTracker c.Assert(memTracker.MaxConsumed(), check.Greater, int64(0)) c.Assert(diskTracker.MaxConsumed(), check.Greater, int64(0)) sort.Ints(vals) resRows = make([]string, 0, rowNum) for i := vals[0]; i <= rowNum; i++ { resRows = append(resRows, fmt.Sprintf("%d", i)) } rows.Check(testkit.Rows(resRows...)) } func (test *CTETestSuite) TestUnionDistinct(c *check.C) { tk := testkit.NewTestKit(c, test.store) tk.MustExec("use test;") // Basic test. UNION/UNION ALL intersects. rows := tk.MustQuery("with recursive cte1(c1) as (select 1 union select 1 union select 1 union all select c1 + 1 from cte1 where c1 < 3) select * from cte1 order by c1;") rows.Check(testkit.Rows("1", "2", "3")) rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union all select 1 union select 1 union all select c1 + 1 from cte1 where c1 < 3) select * from cte1 order by c1;") rows.Check(testkit.Rows("1", "2", "3")) tk.MustExec("drop table if exists t1;") tk.MustExec("create table t1(c1 int, c2 int);") tk.MustExec("insert into t1 values(1, 1), (1, 2), (2, 2);") rows = tk.MustQuery("with recursive cte1(c1) as (select c1 from t1 union select c1 + 1 c1 from t1) select * from cte1 order by c1;") rows.Check(testkit.Rows("1", "2", "3")) tk.MustExec("drop table if exists t1;") tk.MustExec("create table t1(c1 int);") tk.MustExec("insert into t1 values(1), (1), (1), (2), (2), (2);") rows = tk.MustQuery("with recursive cte1(c1) as (select c1 from t1 union select c1 + 1 c1 from cte1 where c1 < 4) select * from cte1 order by c1;") rows.Check(testkit.Rows("1", "2", "3", "4")) } func (test *CTETestSuite) TestCTEMaxRecursionDepth(c *check.C) { tk := testkit.NewTestKit(c, test.store) tk.MustExec("use test;") tk.MustExec("set @@cte_max_recursion_depth = -1;") err := tk.QueryToErr("with recursive cte1(c1) as (select 1 union select c1 + 1 c1 from cte1 where c1 < 100) select * from cte1;") c.Assert(err, check.NotNil) c.Assert(err.Error(), check.Equals, "[executor:3636]Recursive query aborted after 1 iterations. Try increasing @@cte_max_recursion_depth to a larger value") // If there is no recursive part, query runs ok. rows := tk.MustQuery("with recursive cte1(c1) as (select 1 union select 2) select * from cte1 order by c1;") rows.Check(testkit.Rows("1", "2")) rows = tk.MustQuery("with cte1(c1) as (select 1 union select 2) select * from cte1 order by c1;") rows.Check(testkit.Rows("1", "2")) tk.MustExec("set @@cte_max_recursion_depth = 0;") err = tk.QueryToErr("with recursive cte1(c1) as (select 1 union select c1 + 1 c1 from cte1 where c1 < 0) select * from cte1;") c.Assert(err, check.NotNil) c.Assert(err.Error(), check.Equals, "[executor:3636]Recursive query aborted after 1 iterations. Try increasing @@cte_max_recursion_depth to a larger value") err = tk.QueryToErr("with recursive cte1(c1) as (select 1 union select c1 + 1 c1 from cte1 where c1 < 1) select * from cte1;") c.Assert(err, check.NotNil) c.Assert(err.Error(), check.Equals, "[executor:3636]Recursive query aborted after 1 iterations. Try increasing @@cte_max_recursion_depth to a larger value") // If there is no recursive part, query runs ok. rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select 2) select * from cte1 order by c1;") rows.Check(testkit.Rows("1", "2")) rows = tk.MustQuery("with cte1(c1) as (select 1 union select 2) select * from cte1 order by c1;") rows.Check(testkit.Rows("1", "2")) tk.MustExec("set @@cte_max_recursion_depth = 1;") rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select c1 + 1 c1 from cte1 where c1 < 0) select * from cte1;") rows.Check(testkit.Rows("1")) rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select c1 + 1 c1 from cte1 where c1 < 1) select * from cte1;") rows.Check(testkit.Rows("1")) err = tk.QueryToErr("with recursive cte1(c1) as (select 1 union select c1 + 1 c1 from cte1 where c1 < 2) select * from cte1;") c.Assert(err, check.NotNil) c.Assert(err.Error(), check.Equals, "[executor:3636]Recursive query aborted after 2 iterations. Try increasing @@cte_max_recursion_depth to a larger value") // If there is no recursive part, query runs ok. rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select 2) select * from cte1 order by c1;") rows.Check(testkit.Rows("1", "2")) rows = tk.MustQuery("with cte1(c1) as (select 1 union select 2) select * from cte1 order by c1;") rows.Check(testkit.Rows("1", "2")) }