diff --git a/pkg/planner/core/generator/hash64_equals/hash64_equals_generator.go b/pkg/planner/core/generator/hash64_equals/hash64_equals_generator.go index a36637358c..3237a717a3 100644 --- a/pkg/planner/core/generator/hash64_equals/hash64_equals_generator.go +++ b/pkg/planner/core/generator/hash64_equals/hash64_equals_generator.go @@ -37,7 +37,8 @@ import ( // If a field is not tagged, then it will be skipped. func GenHash64Equals4LogicalOps() ([]byte, error) { var structures = []any{logicalop.LogicalJoin{}, logicalop.LogicalAggregation{}, logicalop.LogicalApply{}, - logicalop.LogicalExpand{}, logicalop.LogicalLimit{}, logicalop.LogicalMaxOneRow{}, logicalop.DataSource{}} + logicalop.LogicalExpand{}, logicalop.LogicalLimit{}, logicalop.LogicalMaxOneRow{}, logicalop.DataSource{}, + logicalop.LogicalMemTable{}, logicalop.LogicalUnionAll{}, logicalop.LogicalPartitionUnionAll{}} c := new(cc) c.write(codeGenHash64EqualsPrefix) for _, s := range structures { @@ -125,6 +126,13 @@ func logicalOpName2PlanCodecString(name string) string { return "plancodec.TypeMaxOneRow" case "DataSource": return "plancodec.TypeDataSource" + case "LogicalMemTable": + return "plancodec.TypeMemTableScan" + case "LogicalUnionAll": + return "plancodec.TypeUnion" + case "LogicalPartitionUnionAll": + return "plancodec.TypePartitionUnion" + default: return "" } diff --git a/pkg/planner/core/operator/logicalop/hash64_equals_generated.go b/pkg/planner/core/operator/logicalop/hash64_equals_generated.go index 9f11fa836a..a1e3345bc1 100644 --- a/pkg/planner/core/operator/logicalop/hash64_equals_generated.go +++ b/pkg/planner/core/operator/logicalop/hash64_equals_generated.go @@ -403,6 +403,7 @@ func (op *LogicalExpand) Equals(other any) bool { // Hash64 implements the Hash64Equals interface. func (op *LogicalLimit) Hash64(h base.Hasher) { h.HashString(plancodec.TypeLimit) + op.LogicalSchemaProducer.Hash64(h) if op.PartitionBy == nil { h.HashByte(base.NilFlag) } else { @@ -428,6 +429,9 @@ func (op *LogicalLimit) Equals(other any) bool { if op2 == nil { return false } + if !op.LogicalSchemaProducer.Equals(&op2.LogicalSchemaProducer) { + return false + } if (op.PartitionBy == nil && op2.PartitionBy != nil) || (op.PartitionBy != nil && op2.PartitionBy == nil) || len(op.PartitionBy) != len(op2.PartitionBy) { return false } @@ -549,3 +553,88 @@ func (op *DataSource) Equals(other any) bool { } return true } + +// Hash64 implements the Hash64Equals interface. +func (op *LogicalMemTable) Hash64(h base.Hasher) { + h.HashString(plancodec.TypeMemTableScan) + op.LogicalSchemaProducer.Hash64(h) + op.DBName.Hash64(h) + if op.TableInfo == nil { + h.HashByte(base.NilFlag) + } else { + h.HashByte(base.NotNilFlag) + op.TableInfo.Hash64(h) + } +} + +// Equals implements the Hash64Equals interface, only receive *LogicalMemTable pointer. +func (op *LogicalMemTable) Equals(other any) bool { + op2, ok := other.(*LogicalMemTable) + if !ok { + return false + } + if op == nil { + return op2 == nil + } + if op2 == nil { + return false + } + if !op.LogicalSchemaProducer.Equals(&op2.LogicalSchemaProducer) { + return false + } + if !op.DBName.Equals(&op2.DBName) { + return false + } + if !op.TableInfo.Equals(op2.TableInfo) { + return false + } + return true +} + +// Hash64 implements the Hash64Equals interface. +func (op *LogicalUnionAll) Hash64(h base.Hasher) { + h.HashString(plancodec.TypeUnion) + op.LogicalSchemaProducer.Hash64(h) +} + +// Equals implements the Hash64Equals interface, only receive *LogicalUnionAll pointer. +func (op *LogicalUnionAll) Equals(other any) bool { + op2, ok := other.(*LogicalUnionAll) + if !ok { + return false + } + if op == nil { + return op2 == nil + } + if op2 == nil { + return false + } + if !op.LogicalSchemaProducer.Equals(&op2.LogicalSchemaProducer) { + return false + } + return true +} + +// Hash64 implements the Hash64Equals interface. +func (op *LogicalPartitionUnionAll) Hash64(h base.Hasher) { + h.HashString(plancodec.TypePartitionUnion) + op.LogicalUnionAll.Hash64(h) +} + +// Equals implements the Hash64Equals interface, only receive *LogicalPartitionUnionAll pointer. +func (op *LogicalPartitionUnionAll) Equals(other any) bool { + op2, ok := other.(*LogicalPartitionUnionAll) + if !ok { + return false + } + if op == nil { + return op2 == nil + } + if op2 == nil { + return false + } + if !op.LogicalUnionAll.Equals(&op2.LogicalUnionAll) { + return false + } + return true +} diff --git a/pkg/planner/core/operator/logicalop/logical_limit.go b/pkg/planner/core/operator/logicalop/logical_limit.go index 5e61a4e949..f2a56a63fc 100644 --- a/pkg/planner/core/operator/logicalop/logical_limit.go +++ b/pkg/planner/core/operator/logicalop/logical_limit.go @@ -30,7 +30,7 @@ import ( // LogicalLimit represents offset and limit plan. type LogicalLimit struct { - LogicalSchemaProducer + LogicalSchemaProducer `hash64-equals:"true"` PartitionBy []property.SortItem `hash64-equals:"true"` // This is used for enhanced topN optimization Offset uint64 `hash64-equals:"true"` diff --git a/pkg/planner/core/operator/logicalop/logical_mem_table.go b/pkg/planner/core/operator/logicalop/logical_mem_table.go index 0648b114a0..7e8898aa39 100644 --- a/pkg/planner/core/operator/logicalop/logical_mem_table.go +++ b/pkg/planner/core/operator/logicalop/logical_mem_table.go @@ -40,11 +40,11 @@ import ( // requesting all cluster components log search gRPC interface to retrieve // log message and filtering them in TiDB node. type LogicalMemTable struct { - LogicalSchemaProducer + LogicalSchemaProducer `hash64-equals:"true"` Extractor base.MemTablePredicateExtractor - DBName pmodel.CIStr - TableInfo *model.TableInfo + DBName pmodel.CIStr `hash64-equals:"true"` + TableInfo *model.TableInfo `hash64-equals:"true"` Columns []*model.ColumnInfo // QueryTimeRange is used to specify the time range for metrics summary tables and inspection tables // e.g: select /*+ time_range('2020-02-02 12:10:00', '2020-02-02 13:00:00') */ from metrics_summary; diff --git a/pkg/planner/core/operator/logicalop/logical_partition_union_all.go b/pkg/planner/core/operator/logicalop/logical_partition_union_all.go index 095abee2bd..1bf0b1b52d 100644 --- a/pkg/planner/core/operator/logicalop/logical_partition_union_all.go +++ b/pkg/planner/core/operator/logicalop/logical_partition_union_all.go @@ -23,7 +23,7 @@ import ( // LogicalPartitionUnionAll represents the LogicalUnionAll plan is for partition table. type LogicalPartitionUnionAll struct { - LogicalUnionAll + LogicalUnionAll `hash64-equals:"true"` } // Init initializes LogicalPartitionUnionAll. diff --git a/pkg/planner/core/operator/logicalop/logical_union_all.go b/pkg/planner/core/operator/logicalop/logical_union_all.go index 385e987e19..b15487a7e7 100644 --- a/pkg/planner/core/operator/logicalop/logical_union_all.go +++ b/pkg/planner/core/operator/logicalop/logical_union_all.go @@ -29,7 +29,7 @@ import ( // LogicalUnionAll represents LogicalUnionAll plan. type LogicalUnionAll struct { - LogicalSchemaProducer + LogicalSchemaProducer `hash64-equals:"true"` } // Init initializes LogicalUnionAll. diff --git a/pkg/planner/core/operator/logicalop/logicalop_test/BUILD.bazel b/pkg/planner/core/operator/logicalop/logicalop_test/BUILD.bazel index 32d77841d6..6efece856d 100644 --- a/pkg/planner/core/operator/logicalop/logicalop_test/BUILD.bazel +++ b/pkg/planner/core/operator/logicalop/logicalop_test/BUILD.bazel @@ -8,13 +8,15 @@ go_test( "logical_mem_table_predicate_extractor_test.go", ], flaky = True, - shard_count = 20, + shard_count = 22, deps = [ "//pkg/domain", "//pkg/expression", "//pkg/expression/aggregation", + "//pkg/meta/model", "//pkg/parser", "//pkg/parser/ast", + "//pkg/parser/model", "//pkg/parser/mysql", "//pkg/planner", "//pkg/planner/cascades/base", diff --git a/pkg/planner/core/operator/logicalop/logicalop_test/hash64_equals_test.go b/pkg/planner/core/operator/logicalop/logicalop_test/hash64_equals_test.go index 81a438e6f1..701495c79c 100644 --- a/pkg/planner/core/operator/logicalop/logicalop_test/hash64_equals_test.go +++ b/pkg/planner/core/operator/logicalop/logicalop_test/hash64_equals_test.go @@ -19,7 +19,9 @@ import ( "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/expression/aggregation" + "github.com/pingcap/tidb/pkg/meta/model" "github.com/pingcap/tidb/pkg/parser/ast" + pmodel "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/planner/cascades/base" "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" @@ -29,6 +31,117 @@ import ( "github.com/stretchr/testify/require" ) +func TestLogicalUnionAllHash64Equals(t *testing.T) { + col1 := &expression.Column{ + ID: 1, + Index: 0, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + col2 := &expression.Column{ + ID: 2, + Index: 0, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + // test schema producer. + ctx := mock.NewContext() + u1 := logicalop.LogicalUnionAll{}.Init(ctx, 1) + u1.LogicalSchemaProducer.SetSchema(&expression.Schema{Columns: []*expression.Column{col1}}) + u2 := logicalop.LogicalUnionAll{}.Init(ctx, 1) + u2.LogicalSchemaProducer.SetSchema(&expression.Schema{Columns: []*expression.Column{col1}}) + hasher1 := base.NewHashEqualer() + hasher2 := base.NewHashEqualer() + u1.Hash64(hasher1) + u2.Hash64(hasher2) + require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) + require.True(t, u1.Equals(u2)) + + u2.LogicalSchemaProducer.SetSchema(&expression.Schema{Columns: []*expression.Column{col2}}) + hasher2.Reset() + u2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, u1.Equals(u2)) + + pu1 := logicalop.LogicalPartitionUnionAll{}.Init(ctx, 1) + pu1.LogicalSchemaProducer.SetSchema(&expression.Schema{Columns: []*expression.Column{col1}}) + pu2 := logicalop.LogicalPartitionUnionAll{}.Init(ctx, 1) + pu2.LogicalSchemaProducer.SetSchema(&expression.Schema{Columns: []*expression.Column{col1}}) + hasher1.Reset() + hasher2.Reset() + pu1.Hash64(hasher1) + pu2.Hash64(hasher2) + require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) + require.True(t, pu1.Equals(pu2)) + + pu2.LogicalSchemaProducer.SetSchema(&expression.Schema{Columns: []*expression.Column{col2}}) + hasher2.Reset() + pu2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, pu1.Equals(pu2)) +} + +func TestLogicalMemTableHash64Equals(t *testing.T) { + col1 := &expression.Column{ + ID: 1, + Index: 0, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + col2 := &expression.Column{ + ID: 2, + Index: 0, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + // test schema producer. + ctx := mock.NewContext() + m1 := logicalop.LogicalMemTable{}.Init(ctx, 1) + m1.LogicalSchemaProducer.SetSchema(&expression.Schema{Columns: []*expression.Column{col1}}) + m2 := logicalop.LogicalMemTable{}.Init(ctx, 1) + m2.LogicalSchemaProducer.SetSchema(&expression.Schema{Columns: []*expression.Column{col1}}) + hasher1 := base.NewHashEqualer() + hasher2 := base.NewHashEqualer() + m1.Hash64(hasher1) + m2.Hash64(hasher2) + require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) + require.True(t, m1.Equals(m2)) + + m2.LogicalSchemaProducer.SetSchema(&expression.Schema{Columns: []*expression.Column{col2}}) + hasher2.Reset() + m2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, m1.Equals(m2)) + + m2.LogicalSchemaProducer.SetSchema(&expression.Schema{Columns: []*expression.Column{col1}}) + m2.DBName = pmodel.NewCIStr("d1") + hasher2.Reset() + m2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, m1.Equals(m2)) + + m2.DBName = pmodel.NewCIStr("") + m2.TableInfo = &model.TableInfo{ID: 1} + hasher2.Reset() + m2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, m1.Equals(m2)) + + m2.TableInfo = &model.TableInfo{} + hasher2.Reset() + m2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, m1.Equals(m2)) + + m2.TableInfo = nil + hasher2.Reset() + m2.Hash64(hasher2) + require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) + require.True(t, m1.Equals(m2)) + + m1.TableInfo = &model.TableInfo{ID: 1} + hasher1.Reset() + m1.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + require.False(t, m1.Equals(m2)) +} + func TestLogicalSchemaProducerHash64Equals(t *testing.T) { col1 := &expression.Column{ ID: 1,