planner: take logical schema producer into logical operator's hash generation (#57323)

ref pingcap/tidb#51664
Author: Arenatlx
Date: 2024-11-13 14:11:29 +08:00
Committed by: GitHub
Parent: 65281ad307
Commit: 4cca1ffbc5
12 changed files with 262 additions and 102 deletions
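In short, this change makes an operator's output schema part of its Hash64/Equals identity: LogicalSchemaProducer now implements the HashEquals interface, the embedded producer field is tagged `hash64-equals:"true"` in the affected operators, and the generated Hash64 folds it in right after the operator's type label. Below is a minimal, self-contained Go sketch of that effect; hasher, schemaProducer, and fakeJoin are illustrative stand-ins, not TiDB code.

package main

import "fmt"

// hasher is a stand-in for the planner/cascades/base.Hasher interface.
type hasher interface {
	HashInt64(v int64)
	HashString(s string)
}

// fnvHasher is a toy FNV-1a style 64-bit accumulator.
type fnvHasher struct{ sum uint64 }

func newHasher() *fnvHasher { return &fnvHasher{sum: 14695981039346656037} }

func (h *fnvHasher) HashInt64(v int64) { h.sum = (h.sum ^ uint64(v)) * 1099511628211 }

func (h *fnvHasher) HashString(s string) {
	for i := 0; i < len(s); i++ {
		h.sum = (h.sum ^ uint64(s[i])) * 1099511628211
	}
}

// schemaProducer stands in for LogicalSchemaProducer: it folds the
// operator's output columns into the hash.
type schemaProducer struct{ colIDs []int64 }

func (s *schemaProducer) Hash64(h hasher) {
	for _, id := range s.colIDs {
		h.HashInt64(id)
	}
}

// fakeJoin mirrors the shape of the generated code: hash the operator type
// label, then the embedded schema producer, then each tagged field.
// Untagged fields (here: reordered) are skipped by the generator.
type fakeJoin struct {
	schemaProducer `hash64-equals:"true"`
	joinType       int64 `hash64-equals:"true"`
	reordered      bool
}

func (op *fakeJoin) Hash64(h hasher) {
	h.HashString("Join")
	op.schemaProducer.Hash64(h)
	h.HashInt64(op.joinType)
}

func main() {
	a := &fakeJoin{schemaProducer: schemaProducer{colIDs: []int64{1, 2}}}
	b := &fakeJoin{schemaProducer: schemaProducer{colIDs: []int64{1, 3}}}
	ha, hb := newHasher(), newHasher()
	a.Hash64(ha)
	b.Hash64(hb)
	// Before this change the schema was not hashed, so a and b (identical
	// apart from their output columns) would have collided.
	fmt.Println(ha.sum == hb.sum) // false
}

With the schema folded in, two operators that agree on every tagged field but project different columns no longer hash equal, which is what the new TestLogicalSchemaProducerHash64Equals at the end of this diff checks against the real DataSource operator.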

View File

@ -26,6 +26,7 @@ go_library(
"//pkg/parser/mysql",
"//pkg/parser/terror",
"//pkg/parser/types",
"//pkg/planner/cascades/base",
"//pkg/util/intest",
"@com_github_pingcap_errors//:errors",
"@com_github_tikv_pd_client//http",

View File

@ -26,6 +26,7 @@ import (
"github.com/pingcap/tidb/pkg/parser/duration"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/planner/cascades/base"
)
// ExtraHandleID is the column ID of column which we need to append to schema to occupy the handle's position
@ -196,6 +197,27 @@ type TableInfo struct {
DBID int64 `json:"-"`
}
// Hash64 implement HashEquals interface.
func (t *TableInfo) Hash64(h base.Hasher) {
h.HashInt64(t.ID)
}
// Equals implements HashEquals interface.
func (t *TableInfo) Equals(other any) bool {
// any(nil) can still be converted as (*TableInfo)(nil)
t2, ok := other.(*TableInfo)
if !ok {
return false
}
if t == nil {
return t2 == nil
}
if t2 == nil {
return false
}
return t.ID == t2.ID
}
// SepAutoInc decides whether _rowid and auto_increment id use separate allocator.
func (t *TableInfo) SepAutoInc() bool {
return t.Version >= TableInfoVersion5 && t.AutoIDCache == 1

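TableInfo.Equals above, and the regenerated operator Equals methods later in this diff, type-assert first and then nil-check both the receiver and the asserted pointer instead of bailing out on `other == nil`. The reason is a Go subtlety: an `any` holding a typed nil pointer is not equal to nil, yet it still passes the type assertion. A short, self-contained illustration (not TiDB code):

package main

import "fmt"

type info struct{ id int64 }

// Equals follows the same nil-handling pattern as TableInfo.Equals.
func (t *info) Equals(other any) bool {
	t2, ok := other.(*info) // an any holding (*info)(nil) still asserts successfully
	if !ok {
		return false
	}
	if t == nil {
		return t2 == nil
	}
	if t2 == nil {
		return false
	}
	return t.id == t2.id
}

func main() {
	var p *info                   // typed nil pointer
	fmt.Println(any(p) == nil)    // false: the interface carries the *info type
	fmt.Println(p.Equals(any(p))) // true: receiver and argument are both nil *info
}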
View File

@ -6,6 +6,7 @@ go_library(
importpath = "github.com/pingcap/tidb/pkg/planner/core/generator/hash64_equals",
visibility = ["//visibility:private"],
deps = [
"//pkg/parser/types",
"//pkg/planner/cascades/base",
"//pkg/planner/core/operator/logicalop",
],

View File

@ -23,6 +23,7 @@ import (
"reflect"
"strings"
"github.com/pingcap/tidb/pkg/parser/types"
"github.com/pingcap/tidb/pkg/planner/cascades/base"
"github.com/pingcap/tidb/pkg/planner/core/operator/logicalop"
)
@ -36,7 +37,7 @@ import (
// If a field is not tagged, then it will be skipped.
func GenHash64Equals4LogicalOps() ([]byte, error) {
var structures = []any{logicalop.LogicalJoin{}, logicalop.LogicalAggregation{}, logicalop.LogicalApply{},
logicalop.LogicalExpand{}, logicalop.LogicalLimit{}, logicalop.LogicalMaxOneRow{}}
logicalop.LogicalExpand{}, logicalop.LogicalLimit{}, logicalop.LogicalMaxOneRow{}, logicalop.DataSource{}}
c := new(cc)
c.write(codeGenHash64EqualsPrefix)
for _, s := range structures {
@ -49,7 +50,14 @@ func GenHash64Equals4LogicalOps() ([]byte, error) {
return c.format()
}
// IHashEquals is the interface for hash64 and equals inside parser pkg.
type IHashEquals interface {
Hash64(h types.IHasher)
Equals(other any) bool
}
var hashEqualsType = reflect.TypeOf((*base.HashEquals)(nil)).Elem()
var iHashEqualsType = reflect.TypeOf((*IHashEquals)(nil)).Elem()
func genHash64EqualsForLogicalOps(x any) ([]byte, error) {
c := new(cc)
@ -78,9 +86,10 @@ func genHash64EqualsForLogicalOps(x any) ([]byte, error) {
// for Equals function.
c.write("// Equals implements the Hash64Equals interface, only receive *%v pointer.", vType.Name())
c.write("func (op *%v) Equals(other any) bool {", vType.Name())
c.write("if other == nil { return false }")
c.write("op2, ok := other.(*%v)", vType.Name())
c.write("if !ok { return false }")
c.write("if op == nil { return op2 == nil }")
c.write("if op2 == nil { return false }")
hasValidField := false
for i := 0; i < vType.NumField(); i++ {
f := vType.Field(i)
@ -114,6 +123,8 @@ func logicalOpName2PlanCodecString(name string) string {
return "plancodec.TypeLimit"
case "LogicalMaxOneRow":
return "plancodec.TypeMaxOneRow"
case "DataSource":
return "plancodec.TypeDataSource"
default:
return ""
}
@ -149,7 +160,8 @@ func (c *cc) EqualsElement(fType reflect.Type, lhs, rhs string, i string) {
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64:
c.write("if %v != %v {return false}", lhs, rhs)
default:
if fType.Implements(hashEqualsType) || reflect.PtrTo(fType).Implements(hashEqualsType) {
if fType.Implements(hashEqualsType) || fType.Implements(iHashEqualsType) ||
reflect.PtrTo(fType).Implements(hashEqualsType) || reflect.PtrTo(fType).Implements(iHashEqualsType) {
if fType.Kind() == reflect.Struct {
rhs = "&" + rhs
}
@ -183,7 +195,8 @@ func (c *cc) Hash64Element(fType reflect.Type, callName string) {
case reflect.Float32, reflect.Float64:
c.write("h.HashFloat64(float64(%v))", callName)
default:
if fType.Implements(hashEqualsType) || reflect.PtrTo(fType).Implements(hashEqualsType) {
if fType.Implements(hashEqualsType) || fType.Implements(iHashEqualsType) ||
reflect.PtrTo(fType).Implements(hashEqualsType) || reflect.PtrTo(fType).Implements(iHashEqualsType) {
c.write("%v.Hash64(h)", callName)
} else {
panic("doesn't support element type" + fType.Kind().String())

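The generator now recognizes fields whose types implement either the cascades/base HashEquals interface or the parser-side IHashEquals interface declared above (whose Hash64 takes types.IHasher), checking both the field type and its pointer type. A minimal sketch (assumed names, not TiDB code) of the underlying reflect check:

package main

import (
	"fmt"
	"reflect"
)

// hashEquals is a stand-in for the interfaces the generator probes for.
type hashEquals interface {
	Equals(other any) bool
}

type byValue struct{}

func (byValue) Equals(any) bool { return true }

type byPointer struct{}

func (*byPointer) Equals(any) bool { return true }

var hashEqualsType = reflect.TypeOf((*hashEquals)(nil)).Elem()

// qualifies mirrors the generator's test: a field type qualifies if the
// type itself, or a pointer to it, implements the interface.
func qualifies(t reflect.Type) bool {
	return t.Implements(hashEqualsType) || reflect.PtrTo(t).Implements(hashEqualsType)
}

func main() {
	fmt.Println(qualifies(reflect.TypeOf(byValue{})))   // true: value receiver
	fmt.Println(qualifies(reflect.TypeOf(byPointer{}))) // true: only *byPointer has Equals
	fmt.Println(qualifies(reflect.TypeOf(42)))          // false: int has no Equals method
}

The real generator repeats this probe for both interfaces, which is why the default branches above now test four conditions.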
View File

@ -24,6 +24,7 @@ import (
// Hash64 implements the Hash64Equals interface.
func (op *LogicalJoin) Hash64(h base.Hasher) {
h.HashString(plancodec.TypeJoin)
op.LogicalSchemaProducer.Hash64(h)
h.HashInt64(int64(op.JoinType))
if op.EqualConditions == nil {
h.HashByte(base.NilFlag)
@ -74,13 +75,19 @@ func (op *LogicalJoin) Hash64(h base.Hasher) {
// Equals implements the Hash64Equals interface, only receive *LogicalJoin pointer.
func (op *LogicalJoin) Equals(other any) bool {
if other == nil {
return false
}
op2, ok := other.(*LogicalJoin)
if !ok {
return false
}
if op == nil {
return op2 == nil
}
if op2 == nil {
return false
}
if !op.LogicalSchemaProducer.Equals(&op2.LogicalSchemaProducer) {
return false
}
if op.JoinType != op2.JoinType {
return false
}
@ -130,6 +137,7 @@ func (op *LogicalJoin) Equals(other any) bool {
// Hash64 implements the Hash64Equals interface.
func (op *LogicalAggregation) Hash64(h base.Hasher) {
h.HashString(plancodec.TypeAgg)
op.LogicalSchemaProducer.Hash64(h)
if op.AggFuncs == nil {
h.HashByte(base.NilFlag)
} else {
@ -164,13 +172,19 @@ func (op *LogicalAggregation) Hash64(h base.Hasher) {
// Equals implements the Hash64Equals interface, only receive *LogicalAggregation pointer.
func (op *LogicalAggregation) Equals(other any) bool {
if other == nil {
return false
}
op2, ok := other.(*LogicalAggregation)
if !ok {
return false
}
if op == nil {
return op2 == nil
}
if op2 == nil {
return false
}
if !op.LogicalSchemaProducer.Equals(&op2.LogicalSchemaProducer) {
return false
}
if (op.AggFuncs == nil && op2.AggFuncs != nil) || (op.AggFuncs != nil && op2.AggFuncs == nil) || len(op.AggFuncs) != len(op2.AggFuncs) {
return false
}
@ -221,13 +235,16 @@ func (op *LogicalApply) Hash64(h base.Hasher) {
// Equals implements the Hash64Equals interface, only receive *LogicalApply pointer.
func (op *LogicalApply) Equals(other any) bool {
if other == nil {
return false
}
op2, ok := other.(*LogicalApply)
if !ok {
return false
}
if op == nil {
return op2 == nil
}
if op2 == nil {
return false
}
if !op.LogicalJoin.Equals(&op2.LogicalJoin) {
return false
}
@ -248,6 +265,7 @@ func (op *LogicalApply) Equals(other any) bool {
// Hash64 implements the Hash64Equals interface.
func (op *LogicalExpand) Hash64(h base.Hasher) {
h.HashString(plancodec.TypeExpand)
op.LogicalSchemaProducer.Hash64(h)
if op.DistinctGroupByCol == nil {
h.HashByte(base.NilFlag)
} else {
@ -310,13 +328,19 @@ func (op *LogicalExpand) Hash64(h base.Hasher) {
// Equals implements the Hash64Equals interface, only receive *LogicalExpand pointer.
func (op *LogicalExpand) Equals(other any) bool {
if other == nil {
return false
}
op2, ok := other.(*LogicalExpand)
if !ok {
return false
}
if op == nil {
return op2 == nil
}
if op2 == nil {
return false
}
if !op.LogicalSchemaProducer.Equals(&op2.LogicalSchemaProducer) {
return false
}
if (op.DistinctGroupByCol == nil && op2.DistinctGroupByCol != nil) || (op.DistinctGroupByCol != nil && op2.DistinctGroupByCol == nil) || len(op.DistinctGroupByCol) != len(op2.DistinctGroupByCol) {
return false
}
@ -394,13 +418,16 @@ func (op *LogicalLimit) Hash64(h base.Hasher) {
// Equals implements the Hash64Equals interface, only receive *LogicalLimit pointer.
func (op *LogicalLimit) Equals(other any) bool {
if other == nil {
return false
}
op2, ok := other.(*LogicalLimit)
if !ok {
return false
}
if op == nil {
return op2 == nil
}
if op2 == nil {
return false
}
if (op.PartitionBy == nil && op2.PartitionBy != nil) || (op.PartitionBy != nil && op2.PartitionBy == nil) || len(op.PartitionBy) != len(op2.PartitionBy) {
return false
}
@ -425,13 +452,100 @@ func (op *LogicalMaxOneRow) Hash64(h base.Hasher) {
// Equals implements the Hash64Equals interface, only receive *LogicalMaxOneRow pointer.
func (op *LogicalMaxOneRow) Equals(other any) bool {
if other == nil {
return false
}
op2, ok := other.(*LogicalMaxOneRow)
if !ok {
return false
}
if op == nil {
return op2 == nil
}
if op2 == nil {
return false
}
_ = op2
return true
}
// Hash64 implements the Hash64Equals interface.
func (op *DataSource) Hash64(h base.Hasher) {
h.HashString(plancodec.TypeDataSource)
op.LogicalSchemaProducer.Hash64(h)
if op.TableInfo == nil {
h.HashByte(base.NilFlag)
} else {
h.HashByte(base.NotNilFlag)
op.TableInfo.Hash64(h)
}
if op.TableAsName == nil {
h.HashByte(base.NilFlag)
} else {
h.HashByte(base.NotNilFlag)
op.TableAsName.Hash64(h)
}
if op.PushedDownConds == nil {
h.HashByte(base.NilFlag)
} else {
h.HashByte(base.NotNilFlag)
h.HashInt(len(op.PushedDownConds))
for _, one := range op.PushedDownConds {
one.Hash64(h)
}
}
if op.AllConds == nil {
h.HashByte(base.NilFlag)
} else {
h.HashByte(base.NotNilFlag)
h.HashInt(len(op.AllConds))
for _, one := range op.AllConds {
one.Hash64(h)
}
}
h.HashInt64(int64(op.PreferStoreType))
h.HashBool(op.IsForUpdateRead)
}
// Equals implements the Hash64Equals interface, only receive *DataSource pointer.
func (op *DataSource) Equals(other any) bool {
op2, ok := other.(*DataSource)
if !ok {
return false
}
if op == nil {
return op2 == nil
}
if op2 == nil {
return false
}
if !op.LogicalSchemaProducer.Equals(&op2.LogicalSchemaProducer) {
return false
}
if !op.TableInfo.Equals(op2.TableInfo) {
return false
}
if !op.TableAsName.Equals(op2.TableAsName) {
return false
}
if (op.PushedDownConds == nil && op2.PushedDownConds != nil) || (op.PushedDownConds != nil && op2.PushedDownConds == nil) || len(op.PushedDownConds) != len(op2.PushedDownConds) {
return false
}
for i, one := range op.PushedDownConds {
if !one.Equals(op2.PushedDownConds[i]) {
return false
}
}
if (op.AllConds == nil && op2.AllConds != nil) || (op.AllConds != nil && op2.AllConds == nil) || len(op.AllConds) != len(op2.AllConds) {
return false
}
for i, one := range op.AllConds {
if !one.Equals(op2.AllConds[i]) {
return false
}
}
if op.PreferStoreType != op2.PreferStoreType {
return false
}
if op.IsForUpdateRead != op2.IsForUpdateRead {
return false
}
return true
}

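A recurring detail in the generated Hash64 bodies above is the NilFlag/NotNilFlag byte written before every nullable field, plus the length written before every slice. This keeps a nil pointer, a nil slice, and an empty slice from hashing identically, and stops adjacent optional fields from running together in the byte stream. A minimal sketch (stand-in names, not TiDB code):

package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

const (
	nilFlag    byte = 0
	notNilFlag byte = 1
)

// hashConds mimics how the generated code hashes a slice field such as
// PushedDownConds: a nil/not-nil marker, the length, then the elements.
func hashConds(conds []int64) uint64 {
	h := fnv.New64a()
	if conds == nil {
		h.Write([]byte{nilFlag})
		return h.Sum64()
	}
	h.Write([]byte{notNilFlag})
	var buf [8]byte
	binary.LittleEndian.PutUint64(buf[:], uint64(len(conds)))
	h.Write(buf[:])
	for _, c := range conds {
		binary.LittleEndian.PutUint64(buf[:], uint64(c))
		h.Write(buf[:])
	}
	return h.Sum64()
}

func main() {
	// Without the marker byte, a nil slice and an empty slice would feed the
	// same (empty) byte stream into the hasher and collide.
	fmt.Println(hashConds(nil) == hashConds([]int64{})) // false
}

The Equals side mirrors this with explicit nil/non-nil and length comparisons before each element loop.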
View File

@ -38,7 +38,7 @@ import (
// LogicalAggregation represents an aggregate plan.
type LogicalAggregation struct {
LogicalSchemaProducer
LogicalSchemaProducer `hash64-equals:"true"`
AggFuncs []*aggregation.AggFuncDesc `hash64-equals:"true"`
GroupByItems []expression.Expression `hash64-equals:"true"`

View File

@ -26,7 +26,6 @@ import (
"github.com/pingcap/tidb/pkg/parser/ast"
pmodel "github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
base2 "github.com/pingcap/tidb/pkg/planner/cascades/base"
"github.com/pingcap/tidb/pkg/planner/core/base"
"github.com/pingcap/tidb/pkg/planner/core/constraint"
ruleutil "github.com/pingcap/tidb/pkg/planner/core/rule/util"
@ -48,23 +47,23 @@ import (
// DataSource represents a tableScan without condition push down.
type DataSource struct {
LogicalSchemaProducer
LogicalSchemaProducer `hash64-equals:"true"`
AstIndexHints []*ast.IndexHint
IndexHints []h.HintedIndex
Table table.Table
TableInfo *model.TableInfo
TableInfo *model.TableInfo `hash64-equals:"true"`
Columns []*model.ColumnInfo
DBName pmodel.CIStr
TableAsName *pmodel.CIStr
TableAsName *pmodel.CIStr `hash64-equals:"true"`
// IndexMergeHints are the hint for indexmerge.
IndexMergeHints []h.HintedIndex
// PushedDownConds are the conditions that will be pushed down to coprocessor.
PushedDownConds []expression.Expression
PushedDownConds []expression.Expression `hash64-equals:"true"`
// AllConds contains all the filters on this table. For now it's maintained
// in predicate push down and used in partition pruning/index merge.
AllConds []expression.Expression
AllConds []expression.Expression `hash64-equals:"true"`
StatisticTable *statistics.Table
TableStats *property.StatsInfo
@ -92,7 +91,7 @@ type DataSource struct {
// it is converted from StatisticTable, and used for IO/network cost estimating.
TblColHists *statistics.HistColl
// PreferStoreType means the DataSource is enforced to which storage.
PreferStoreType int
PreferStoreType int `hash64-equals:"true"`
// PreferPartitions store the map, the key represents store type, the value represents the partition name list.
PreferPartitions map[int][]pmodel.CIStr
SampleInfo *tablesampler.TableSampleInfo
@ -100,7 +99,7 @@ type DataSource struct {
// IsForUpdateRead should be true in either of the following situations
// 1. use `inside insert`, `update`, `delete` or `select for update` statement
// 2. isolation level is RC
IsForUpdateRead bool
IsForUpdateRead bool `hash64-equals:"true"`
// contain unique index and the first field is tidb_shard(),
// such as (tidb_shard(a), a ...), the fields are more than 2
@ -122,74 +121,6 @@ func (ds DataSource) Init(ctx base.PlanContext, offset int) *DataSource {
return &ds
}
// ************************ start implementation of HashEquals interface ************************
// Hash64 implements base.HashEquals interface.
func (ds *DataSource) Hash64(h base2.Hasher) {
// hash the key elements to identify this datasource.
h.HashString(plancodec.TypeDataSource)
// table related.
if ds.TableInfo == nil {
h.HashByte(base2.NilFlag)
} else {
h.HashByte(base2.NotNilFlag)
h.HashInt64(ds.TableInfo.ID)
}
// table alias related.
if ds.TableAsName == nil {
h.HashByte(base2.NilFlag)
} else {
h.HashByte(base2.NotNilFlag)
h.HashInt(len(ds.TableAsName.L))
h.HashString(ds.TableAsName.L)
}
// visible columns related.
h.HashInt(len(ds.Columns))
for _, oneCol := range ds.Columns {
h.HashInt64(oneCol.ID)
}
// conditions related.
h.HashInt(len(ds.PushedDownConds))
for _, oneCond := range ds.PushedDownConds {
oneCond.Hash64(h)
}
h.HashInt(len(ds.AllConds))
for _, oneCond := range ds.AllConds {
oneCond.Hash64(h)
}
// hint and update misc.
h.HashInt(ds.PreferStoreType)
h.HashBool(ds.IsForUpdateRead)
}
// Equals implements base.HashEquals interface.
func (ds *DataSource) Equals(other any) bool {
if other == nil {
return false
}
var ds2 *DataSource
switch x := other.(type) {
case *DataSource:
ds2 = x
case DataSource:
ds2 = &x
default:
return false
}
ok := ds.TableInfo.ID == ds2.TableInfo.ID && ds.TableAsName.L == ds2.TableAsName.L && len(ds.PushedDownConds) == len(ds2.PushedDownConds) &&
len(ds.AllConds) == len(ds2.AllConds) && ds.PreferStoreType == ds2.PreferStoreType && ds.IsForUpdateRead == ds2.IsForUpdateRead
if !ok {
return false
}
for i, oneCond := range ds.PushedDownConds {
oneCond.Equals(ds2.PushedDownConds[i])
}
for i, oneCond := range ds.AllConds {
oneCond.Equals(ds2.AllConds[i])
}
return true
}
// *************************** start implementation of Plan interface ***************************
// ExplainInfo implements Plan interface.

View File

@ -33,7 +33,7 @@ import (
// LogicalExpand represents a logical Expand OP serves for data replication requirement.
type LogicalExpand struct {
LogicalSchemaProducer
LogicalSchemaProducer `hash64-equals:"true"`
// distinct group by columns. (maybe projected below if it's a non-col)
DistinctGroupByCol []*expression.Column `hash64-equals:"true"`

View File

@ -93,7 +93,7 @@ func (tp JoinType) String() string {
// LogicalJoin is the logical join plan.
type LogicalJoin struct {
LogicalSchemaProducer
LogicalSchemaProducer `hash64-equals:"true"`
JoinType JoinType `hash64-equals:"true"`
Reordered bool

View File

@ -18,11 +18,14 @@ import (
"math"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/planner/cascades/base"
"github.com/pingcap/tidb/pkg/planner/util/optimizetrace"
"github.com/pingcap/tidb/pkg/planner/util/optimizetrace/logicaltrace"
"github.com/pingcap/tidb/pkg/types"
)
var _ base.HashEquals = &LogicalSchemaProducer{}
// LogicalSchemaProducer stores the schema for the logical plans who can produce schema directly.
type LogicalSchemaProducer struct {
schema *expression.Schema
@ -30,6 +33,50 @@ type LogicalSchemaProducer struct {
BaseLogicalPlan
}
// Hash64 implements HashEquals interface.
func (s *LogicalSchemaProducer) Hash64(h base.Hasher) {
// output columns should affect the logical operator's hash.
// since tidb doesn't maintain the names strictly, we should
// only use the schema unique id to distinguish them.
if s.schema != nil {
h.HashByte(base.NotNilFlag)
for _, col := range s.schema.Columns {
col.Hash64(h)
}
} else {
h.HashByte(base.NilFlag)
}
}
// Equals implement HashEquals interface.
func (s *LogicalSchemaProducer) Equals(other any) bool {
s2, ok := other.(*LogicalSchemaProducer)
if !ok {
return false
}
if s == nil {
return s2 == nil
}
if s2 == nil {
return false
}
if s.schema == nil {
return s2.schema == nil
}
if s2.schema == nil {
return false
}
if s.schema.Len() != s2.schema.Len() {
return false
}
for i, col := range s.schema.Columns {
if !col.Equals(s2.schema.Columns[i]) {
return false
}
}
return true
}
// Schema implements the Plan.Schema interface.
func (s *LogicalSchemaProducer) Schema() *expression.Schema {
if s.schema == nil {

View File

@ -8,7 +8,7 @@ go_test(
"logical_mem_table_predicate_extractor_test.go",
],
flaky = True,
shard_count = 19,
shard_count = 20,
deps = [
"//pkg/domain",
"//pkg/expression",

View File

@ -29,6 +29,37 @@ import (
"github.com/stretchr/testify/require"
)
func TestLogicalSchemaProducerHash64Equals(t *testing.T) {
col1 := &expression.Column{
ID: 1,
Index: 0,
RetType: types.NewFieldType(mysql.TypeLonglong),
}
col2 := &expression.Column{
ID: 2,
Index: 0,
RetType: types.NewFieldType(mysql.TypeLonglong),
}
ctx := mock.NewContext()
d1 := logicalop.DataSource{}.Init(ctx, 1)
d1.LogicalSchemaProducer.SetSchema(&expression.Schema{Columns: []*expression.Column{col1}})
d2 := logicalop.DataSource{}.Init(ctx, 1)
d2.LogicalSchemaProducer.SetSchema(&expression.Schema{Columns: []*expression.Column{col1}})
hasher1 := base.NewHashEqualer()
hasher2 := base.NewHashEqualer()
d1.Hash64(hasher1)
d2.Hash64(hasher2)
require.Equal(t, hasher1.Sum64(), hasher2.Sum64())
require.True(t, d1.Equals(d2))
d2.LogicalSchemaProducer.SetSchema(&expression.Schema{Columns: []*expression.Column{col2}})
hasher2.Reset()
d2.Hash64(hasher2)
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
require.False(t, d1.Equals(d2))
}
func TestLogicalMaxOneRowHash64Equals(t *testing.T) {
m1 := &logicalop.LogicalMaxOneRow{}
m2 := &logicalop.LogicalMaxOneRow{}