diff --git a/executor/batch_point_get_test.go b/executor/batch_point_get_test.go index 9bc170107b..4b1eda8abe 100644 --- a/executor/batch_point_get_test.go +++ b/executor/batch_point_get_test.go @@ -370,3 +370,32 @@ func TestPointGetForTemporaryTable(t *testing.T) { tk.MustQuery("select * from t1 where id = 1").Check(testkit.Rows("1 1")) tk.MustQuery("select * from t1 where id = 2").Check(testkit.Rows()) } + +func TestBatchPointGetIssue46779(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("CREATE TABLE t1 (id int, c varchar(128), primary key (id)) PARTITION BY HASH (id) PARTITIONS 3;") + tk.MustExec(`insert into t1 values (1, "a"), (11, "b"), (21, "c")`) + query := "select * from t1 where id in (1, 1, 11)" + require.True(t, tk.HasPlan(query, "Batch_Point_Get")) // check if BatchPointGet is used + tk.MustQuery(query).Sort().Check(testkit.Rows("1 a", "11 b")) + query = "select * from t1 where id in (1, 11, 11, 21)" + require.True(t, tk.HasPlan(query, "Batch_Point_Get")) // check if BatchPointGet is used + tk.MustQuery(query).Sort().Check(testkit.Rows("1 a", "11 b", "21 c")) + + tk.MustExec("drop table if exists t2") + tk.MustExec(`CREATE TABLE t2 (id int, c varchar(128), primary key (id)) partition by range (id)( + partition p0 values less than (10), + partition p1 values less than (20), + partition p2 values less than (30));`) + tk.MustExec(`insert into t2 values (1, "a"), (11, "b"), (21, "c")`) + query = "select * from t2 where id in (1, 1, 11)" + require.True(t, tk.HasPlan(query, "Batch_Point_Get")) // check if BatchPointGet is used + tk.MustQuery(query).Sort().Check(testkit.Rows("1 a", "11 b")) + require.True(t, tk.HasPlan(query, "Batch_Point_Get")) // check if BatchPointGet is used + query = "select * from t2 where id in (1, 11, 11, 21)" + tk.MustQuery(query).Sort().Check(testkit.Rows("1 a", "11 b", "21 c")) +} diff --git a/executor/builder.go b/executor/builder.go index 7fe26157b4..27166a2160 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -5226,16 +5226,22 @@ func (b *executorBuilder) buildBatchPointGet(plan *plannercore.BatchPointGetPlan // `SELECT a FROM t WHERE a IN (1, 1, 2, 1, 2)` should not return duplicated rows handles := make([]kv.Handle, 0, len(plan.Handles)) dedup := kv.NewHandleMap() + // Used for clear paritionIDs of duplicated rows. + dupPartPos := 0 if plan.IndexInfo == nil { - for _, handle := range plan.Handles { + for idx, handle := range plan.Handles { if _, found := dedup.Get(handle); found { continue } dedup.Set(handle, true) handles = append(handles, handle) + if len(plan.PartitionIDs) > 0 { + e.planPhysIDs[dupPartPos] = e.planPhysIDs[idx] + dupPartPos++ + } } } else { - for _, value := range plan.IndexValues { + for idx, value := range plan.IndexValues { if datumsContainNull(value) { continue } @@ -5257,9 +5263,16 @@ func (b *executorBuilder) buildBatchPointGet(plan *plannercore.BatchPointGetPlan } dedup.Set(handle, true) handles = append(handles, handle) + if len(plan.PartitionIDs) > 0 { + e.planPhysIDs[dupPartPos] = e.planPhysIDs[idx] + dupPartPos++ + } } } e.handles = handles + if dupPartPos > 0 { + e.planPhysIDs = e.planPhysIDs[:dupPartPos] + } capacity = len(e.handles) } e.Base().SetInitCap(capacity)