diff --git a/executor/batch_point_get.go b/executor/batch_point_get.go index 83d3746ca3..c6a6195b69 100644 --- a/executor/batch_point_get.go +++ b/executor/batch_point_get.go @@ -25,6 +25,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/hack" + "github.com/pingcap/tidb/util/math" "github.com/pingcap/tidb/util/rowcodec" ) @@ -35,6 +36,8 @@ type BatchPointGetExec struct { tblInfo *model.TableInfo idxInfo *model.IndexInfo handles []int64 + physIDs []int64 + partPos int idxVals [][]types.Datum startTS uint64 snapshotTS uint64 @@ -111,7 +114,8 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error { dedup := make(map[hack.MutableString]struct{}) keys := make([]kv.Key, 0, len(e.idxVals)) for _, idxVals := range e.idxVals { - idxKey, err1 := encodeIndexKey(e.base(), e.tblInfo, e.idxInfo, idxVals, e.tblInfo.ID) + physID := getPhysID(e.tblInfo, idxVals[e.partPos].GetInt64()) + idxKey, err1 := encodeIndexKey(e.base(), e.tblInfo, e.idxInfo, idxVals, physID) if err1 != nil && !kv.ErrNotExist.Equal(err1) { return err1 } @@ -139,6 +143,9 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error { } e.handles = make([]int64, 0, len(keys)) + if e.tblInfo.Partition != nil { + e.physIDs = make([]int64, 0, len(keys)) + } for _, key := range keys { handleVal := handleVals[string(key)] if len(handleVal) == 0 { @@ -149,6 +156,9 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error { return err1 } e.handles = append(e.handles, handle) + if e.tblInfo.Partition != nil { + e.physIDs = append(e.physIDs, tablecodec.DecodeTableID(key)) + } } // The injection is used to simulate following scenario: @@ -175,7 +185,13 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error { keys := make([]kv.Key, len(e.handles)) for i, handle := range e.handles { - key := tablecodec.EncodeRowKeyWithHandle(e.tblInfo.ID, handle) + var tID int64 + if len(e.physIDs) > 0 { + tID = e.physIDs[i] + } else { + tID = getPhysID(e.tblInfo, handle) + } + key := tablecodec.EncodeRowKeyWithHandle(tID, handle) keys[i] = key } @@ -255,3 +271,12 @@ func (e *BatchPointGetExec) lockKeys(ctx context.Context, keys []kv.Key) (map[st } return nil, nil } + +func getPhysID(tblInfo *model.TableInfo, val int64) int64 { + pi := tblInfo.Partition + if pi == nil { + return tblInfo.ID + } + partIdx := math.Abs(val) % int64(pi.Num) + return pi.Definitions[partIdx].ID +} diff --git a/executor/builder.go b/executor/builder.go index e27290f01f..f53e45c80c 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -2945,6 +2945,7 @@ func (b *executorBuilder) buildBatchPointGet(plan *plannercore.BatchPointGetPlan desc: plan.Desc, lock: plan.Lock, waitTime: plan.LockWaitTime, + partPos: plan.PartitionColPos, } if e.lock { b.hasLock = true diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index 09d1ae4b1b..2e5905ab4f 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -207,6 +207,7 @@ type BatchPointGetPlan struct { HandleParams []*driver.ParamMarkerExpr IndexValues [][]types.Datum IndexValueParams [][]*driver.ParamMarkerExpr + PartitionColPos int KeepOrder bool Desc bool Lock bool @@ -377,12 +378,15 @@ func getLockWaitTime(ctx sessionctx.Context, lockTp ast.SelectLockType) (lock bo func newBatchPointGetPlan( ctx sessionctx.Context, patternInExpr *ast.PatternInExpr, - tryHandle bool, fieldType *types.FieldType, - tbl *model.TableInfo, schema *expression.Schema, + handleCol *model.ColumnInfo, tbl *model.TableInfo, schema *expression.Schema, names []*types.FieldName, whereColNames []string, ) *BatchPointGetPlan { statsInfo := &property.StatsInfo{RowCount: float64(len(patternInExpr.List))} - if tryHandle && fieldType != nil { + partitionColName := getHashPartitionColumnName(ctx, tbl) + if tbl.GetPartitionInfo() != nil && partitionColName == nil { + return nil + } + if handleCol != nil { var handles = make([]int64, len(patternInExpr.List)) var handleParams = make([]*driver.ParamMarkerExpr, len(patternInExpr.List)) for i, item := range patternInExpr.List { @@ -404,7 +408,7 @@ func newBatchPointGetPlan( if d.IsNull() { return nil } - intDatum, err := d.ConvertTo(ctx.GetSessionVars().StmtCtx, fieldType) + intDatum, err := d.ConvertTo(ctx.GetSessionVars().StmtCtx, &handleCol.FieldType) if err != nil { return nil } @@ -499,6 +503,7 @@ func newBatchPointGetPlan( IndexInfo: matchIdxInfo, IndexValues: indexValues, IndexValueParams: indexValueParams, + PartitionColPos: getPartitionColumnPos(matchIdxInfo, partitionColName), }.Init(ctx, statsInfo, schema, names, 0) } @@ -522,14 +527,6 @@ func tryWhereIn2BatchPointGet(ctx sessionctx.Context, selStmt *ast.SelectStmt) * return nil } - // Do not handle partitioned table. - // Table partition implementation translates LogicalPlan from `DataSource` to - // `Union -> DataSource` in the logical plan optimization pass, since BatchPointGetPlan - // bypass the logical plan optimization, it can't support partitioned table. - if tbl.GetPartitionInfo() != nil { - return nil - } - for _, col := range tbl.Columns { if col.IsGenerated() || col.State != model.StatePublic { return nil @@ -542,8 +539,7 @@ func tryWhereIn2BatchPointGet(ctx sessionctx.Context, selStmt *ast.SelectStmt) * } var ( - tryHandle bool - fieldType *types.FieldType + handleCol *model.ColumnInfo whereColNames []string ) @@ -561,14 +557,13 @@ func tryWhereIn2BatchPointGet(ctx sessionctx.Context, selStmt *ast.SelectStmt) * if tbl.PKIsHandle { for _, col := range tbl.Columns { if mysql.HasPriKeyFlag(col.Flag) && col.Name.L == colName.Name.Name.L { - tryHandle = true - fieldType = &col.FieldType + handleCol = col whereColNames = append(whereColNames, col.Name.L) break } } } - if !tryHandle { + if handleCol == nil { // Downgrade to use unique index whereColNames = append(whereColNames, colName.Name.Name.L) } @@ -588,7 +583,7 @@ func tryWhereIn2BatchPointGet(ctx sessionctx.Context, selStmt *ast.SelectStmt) * return nil } - p := newBatchPointGetPlan(ctx, in, tryHandle, fieldType, tbl, schema, names, whereColNames) + p := newBatchPointGetPlan(ctx, in, handleCol, tbl, schema, names, whereColNames) if p == nil { return nil } @@ -627,10 +622,6 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP if tbl == nil { return nil } - // Do not handle partitioned table. - // Table partition implementation translates LogicalPlan from `DataSource` to - // `Union -> DataSource` in the logical plan optimization pass, since PointGetPlan - // bypass the logical plan optimization, it can't support partitioned table. pi := tbl.GetPartitionInfo() for _, col := range tbl.Columns { // Do not handle generated columns. @@ -1103,27 +1094,16 @@ func (p *PointGetPlan) findHandleCol() *expression.Column { } func getPartitionInfo(ctx sessionctx.Context, tbl *model.TableInfo, pairs []nameValuePair) (*model.PartitionDefinition, int) { - is := infoschema.GetInfoSchema(ctx) - table, ok := is.TableByID(tbl.ID) - if !ok { + partitionColName := getHashPartitionColumnName(ctx, tbl) + if partitionColName == nil { return nil, 0 } pi := tbl.Partition - if partitionTable, ok := table.(partitionTable); ok { - // PartitionExpr don't need columns and names for hash partition. - partitionExpr, err := partitionTable.PartitionExpr(ctx, nil, nil) - if err != nil { - return nil, 0 - } - expr := partitionExpr.OrigExpr - if col, ok := expr.(*ast.ColumnNameExpr); ok { - for i, pair := range pairs { - if col.Name.Name.L == pair.colName { - val := pair.value.GetInt64() - pos := math.Abs(val) % int64(pi.Num) - return &pi.Definitions[pos], i - } - } + for i, pair := range pairs { + if partitionColName.Name.L == pair.colName { + val := pair.value.GetInt64() + pos := math.Abs(val) % int64(pi.Num) + return &pi.Definitions[pos], i } } return nil, 0 @@ -1137,3 +1117,42 @@ func findPartitionIdx(idxInfo *model.IndexInfo, pos int, pairs []nameValuePair) } return 0 } + +// getPartitionColumnPos gets the partition column's position in the index. +func getPartitionColumnPos(idx *model.IndexInfo, partitionColName *ast.ColumnName) int { + if partitionColName == nil { + return 0 + } + for i, idxCol := range idx.Columns { + if partitionColName.Name.L == idxCol.Name.L { + return i + } + } + panic("unique index must include all partition columns") +} + +func getHashPartitionColumnName(ctx sessionctx.Context, tbl *model.TableInfo) *ast.ColumnName { + pi := tbl.GetPartitionInfo() + if pi == nil { + return nil + } + if pi.Type != model.PartitionTypeHash { + return nil + } + is := infoschema.GetInfoSchema(ctx) + table, ok := is.TableByID(tbl.ID) + if !ok { + return nil + } + // PartitionExpr don't need columns and names for hash partition. + partitionExpr, err := table.(partitionTable).PartitionExpr(ctx, nil, nil) + if err != nil { + return nil + } + expr := partitionExpr.OrigExpr + col, ok := expr.(*ast.ColumnNameExpr) + if !ok { + return nil + } + return col.Name +} diff --git a/planner/core/point_get_plan_test.go b/planner/core/point_get_plan_test.go index 53280ce6c9..522c87a1f9 100644 --- a/planner/core/point_get_plan_test.go +++ b/planner/core/point_get_plan_test.go @@ -371,3 +371,36 @@ func (s *testPointGetSuite) TestBatchPointGetPlanCache(c *C) { "4 4", )) } + +func (s *testPointGetSuite) TestBatchPointGetPartition(c *C) { + tk := testkit.NewTestKit(c, s.store) + orgEnable := core.PreparedPlanCacheEnabled() + defer func() { + core.SetPreparedPlanCache(orgEnable) + }() + core.SetPreparedPlanCache(true) + + var err error + tk.Se, err = session.CreateSession4TestWithOpt(s.store, &session.Opt{ + PreparedPlanCache: kvcache.NewSimpleLRUCache(100, 0.1, math.MaxUint64), + }) + c.Assert(err, IsNil) + + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int primary key, b int) PARTITION BY HASH(a) PARTITIONS 4") + tk.MustExec("insert into t values (1, 1), (2, 2), (3, 3), (4, 4)") + tk.MustQuery("explain select * from t where a in (1, 2, 3, 4)").Check(testkit.Rows( + "Batch_Point_Get_1 4.00 root table:t, handle:[1 2 3 4], keep order:false, desc:false", + )) + tk.MustQuery("select * from t where a in (1, 2, 3, 4)").Check(testkit.Rows("1 1", "2 2", "3 3", "4 4")) + + tk.MustExec("drop table t") + tk.MustExec("create table t(a int, b int, c int, primary key (a, b)) PARTITION BY HASH(a) PARTITIONS 4") + tk.MustExec("insert into t values (1, 1, 1), (2, 2, 2), (3, 3, 3), (4, 4, 4)") + tk.MustQuery("explain select * from t where (a, b) in ((1, 1), (2, 2), (3, 3), (4, 4))").Check(testkit.Rows( + "Batch_Point_Get_1 4.00 root table:t, index:a b, keep order:false, desc:false", + )) + tk.MustQuery("select * from t where (a, b) in ((1, 1), (2, 2), (3, 3), (4, 4))"). + Check(testkit.Rows("1 1 1", "2 2 2", "3 3 3", "4 4 4")) +}