diff --git a/pkg/ttl/cache/table.go b/pkg/ttl/cache/table.go index 4e726158c4..2a8d46b55f 100644 --- a/pkg/ttl/cache/table.go +++ b/pkg/ttl/cache/table.go @@ -185,28 +185,101 @@ func (t *PhysicalTable) ValidateKeyPrefix(key []types.Datum) error { return nil } -// EvalExpireTime returns the expired time +var mockExpireTimeKey struct{} + +// SetMockExpireTime can only be used in tests +func SetMockExpireTime(ctx context.Context, tm time.Time) context.Context { + return context.WithValue(ctx, mockExpireTimeKey, tm) +} + +// EvalExpireTime returns the expired time. +// It uses the global timezone to compute the expired time. +// Then we'll reset the returned expired time to the same timezone as the input `now`. func (t *PhysicalTable) EvalExpireTime(ctx context.Context, se session.Session, now time.Time) (expire time.Time, err error) { - tz := se.GetSessionVars().Location() + if intest.InTest { + if tm, ok := ctx.Value(mockExpireTimeKey).(time.Time); ok { + return tm, err + } + } expireExpr := t.TTLInfo.IntervalExprStr unit := ast.TimeUnitType(t.TTLInfo.IntervalTimeUnit) + // Use the global time zone to compute expire time. + // Different timezones may have different results even with the same "now" time and TTL expression. + // Consider a TTL setting with the expiration `INTERVAL 1 MONTH`. + // If the current timezone is `Asia/Shanghai` and now is `2021-03-01 00:00:00 +0800` + // the expired time should be `2021-02-01 00:00:00 +0800`, corresponding to UTC time `2021-01-31 16:00:00 UTC`. + // But if we use the `UTC` time zone, the current time is `2021-02-28 16:00:00 UTC`, + // and the expired time should be `2021-01-28 16:00:00 UTC` which is not the same as the previous one. + globalTz, err := se.GlobalTimeZone(ctx) + if err != nil { + return time.Time{}, err + } + var rows []chunk.Row + + // We should set the session time zone to UTC because the next SQLs should be executed in the UTC timezone. 
+ // The session time zone should be reverted to the original one after the SQLs are executed. + rows, err = se.ExecuteSQL(ctx, "SELECT @@time_zone") + if err != nil { + return + } + + originalTZ := rows[0].GetString(0) + if _, err = se.ExecuteSQL(ctx, "SET @@time_zone='UTC'"); err != nil { + return + } + + defer func() { + _, restoreErr := se.ExecuteSQL(ctx, "SET @@time_zone=%?", originalTZ) + if err == nil { + err = restoreErr + } + }() + + // Firstly, we should use the UTC time zone to compute the expired time to avoid time shift caused by DST. + // The start time should be a time with the same datetime string as `now` but it is in the UTC timezone. + // For example, if global timezone is `Asia/Shanghai` with a string format `2020-01-01 08:00:00 +0800`, + // the startTime should be in timezone `UTC` and have a string format `2020-01-01 08:00:00 +0000` which is not the + // same as the original one (`2020-01-01 00:00:00 +0000` in UTC actually). + nowInGlobalTZ := now.In(globalTz) + startTime := time.Date( + nowInGlobalTZ.Year(), nowInGlobalTZ.Month(), nowInGlobalTZ.Day(), + nowInGlobalTZ.Hour(), nowInGlobalTZ.Minute(), nowInGlobalTZ.Second(), + nowInGlobalTZ.Nanosecond(), time.UTC, + ) + rows, err = se.ExecuteSQL( ctx, // FROM_UNIXTIME does not support negative value, so we use `FROM_UNIXTIME(0) + INTERVAL ` // to present current time - fmt.Sprintf("SELECT FROM_UNIXTIME(0) + INTERVAL %d SECOND - INTERVAL %s %s", now.Unix(), expireExpr, unit.String()), + fmt.Sprintf("SELECT FROM_UNIXTIME(0) + INTERVAL %d MICROSECOND - INTERVAL %s %s", + startTime.UnixMicro(), + expireExpr, + unit.String(), + ), ) if err != nil { return } - tm := rows[0].GetTime(0) - return tm.CoreTime().GoTime(tz) + tm, err := rows[0].GetTime(0).GoTime(time.UTC) + if err != nil { + return + } + + // Then we should add the duration between the time obtained from the previous SQL and the start time to the now time. + expiredTime := nowInGlobalTZ. + In(now.Location()). + Add(tm.Sub(startTime)). 
+ // Truncate to second to make sure the precision is always the same with the one stored in a table to avoid some + // comparing problems in testing. + Truncate(time.Second) + + return expiredTime, nil } // SplitScanRanges split ranges for TTL scan diff --git a/pkg/ttl/cache/table_test.go b/pkg/ttl/cache/table_test.go index ed21907fb2..2a9cc36443 100644 --- a/pkg/ttl/cache/table_test.go +++ b/pkg/ttl/cache/table_test.go @@ -167,6 +167,7 @@ func TestEvalTTLExpireTime(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("create table test.t(a int, t datetime) ttl = `t` + interval 1 day") tk.MustExec("create table test.t2(a int, t datetime) ttl = `t` + interval 3 month") + tk.MustExec("set @@time_zone='Asia/Tokyo'") tb, err := do.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t")) require.NoError(t, err) @@ -188,29 +189,114 @@ func TestEvalTTLExpireTime(t *testing.T) { tz2, err := time.LoadLocation("Europe/Berlin") require.NoError(t, err) - se.GetSessionVars().TimeZone = tz1 + _, err = se.ExecuteSQL(context.TODO(), "SET @@global.time_zone = 'Asia/Shanghai'") + require.NoError(t, err) tm, err := ttlTbl.EvalExpireTime(context.TODO(), se, now) require.NoError(t, err) require.Equal(t, now.Add(-time.Hour*24).Unix(), tm.Unix()) - require.Equal(t, "1969-12-31 08:00:00", tm.Format(time.DateTime)) - require.Equal(t, tz1.String(), tm.Location().String()) + require.Equal(t, "1969-12-31 08:00:00", tm.In(tz1).Format(time.DateTime)) + require.Same(t, now.Location(), tm.Location()) - se.GetSessionVars().TimeZone = tz2 + _, err = se.ExecuteSQL(context.TODO(), "SET @@global.time_zone = 'Europe/Berlin'") + require.NoError(t, err) tm, err = ttlTbl.EvalExpireTime(context.TODO(), se, now) require.NoError(t, err) require.Equal(t, now.Add(-time.Hour*24).Unix(), tm.Unix()) - require.Equal(t, "1969-12-31 01:00:00", tm.Format(time.DateTime)) - require.Equal(t, tz2.String(), tm.Location().String()) + require.Equal(t, "1969-12-31 01:00:00", 
tm.In(tz2).Format(time.DateTime)) + require.Same(t, now.Location(), tm.Location()) - se.GetSessionVars().TimeZone = tz1 + _, err = se.ExecuteSQL(context.TODO(), "SET @@global.time_zone = 'Asia/Shanghai'") + require.NoError(t, err) tm, err = ttlTbl2.EvalExpireTime(context.TODO(), se, now) require.NoError(t, err) - require.Equal(t, "1969-10-01 08:00:00", tm.Format(time.DateTime)) - require.Equal(t, tz1.String(), tm.Location().String()) + require.Equal(t, "1969-10-01 08:00:00", tm.In(tz1).Format(time.DateTime)) + require.Same(t, now.Location(), tm.Location()) - se.GetSessionVars().TimeZone = tz2 + _, err = se.ExecuteSQL(context.TODO(), "SET @@global.time_zone = 'Europe/Berlin'") + require.NoError(t, err) tm, err = ttlTbl2.EvalExpireTime(context.TODO(), se, now) require.NoError(t, err) - require.Equal(t, "1969-10-01 01:00:00", tm.Format(time.DateTime)) - require.Equal(t, tz2.String(), tm.Location().String()) + require.Equal(t, "1969-10-01 01:00:00", tm.In(tz2).Format(time.DateTime)) + require.Same(t, now.Location(), tm.Location()) + + // test cases for daylight saving time. + // When local standard time was about to reach Sunday, 10 March 2024, 02:00:00 clocks were turned forward 1 hour to + // Sunday, 10 March 2024, 03:00:00 local daylight time instead. 
+ tz3, err := time.LoadLocation("America/Los_Angeles") + require.NoError(t, err) + now, err = time.ParseInLocation(time.DateTime, "2024-03-11 19:49:59", tz3) + require.NoError(t, err) + _, err = se.ExecuteSQL(context.TODO(), "SET @@global.time_zone = 'America/Los_Angeles'") + require.NoError(t, err) + tk.MustExec("create table test.t3(a int, t datetime) ttl = `t` + interval 90 minute") + tb3, err := do.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t3")) + require.NoError(t, err) + tblInfo3 := tb3.Meta() + ttlTbl3, err := cache.NewPhysicalTable(model.NewCIStr("test"), tblInfo3, model.NewCIStr("")) + require.NoError(t, err) + require.NoError(t, err) + tm, err = ttlTbl3.EvalExpireTime(context.TODO(), se, now) + require.NoError(t, err) + require.Equal(t, "2024-03-11 18:19:59", tm.In(tz3).Format(time.DateTime)) + require.Same(t, now.Location(), tm.Location()) + + // across daylight-saving time + now, err = time.ParseInLocation(time.DateTime, "2024-03-10 03:01:00", tz3) + require.NoError(t, err) + tm, err = ttlTbl3.EvalExpireTime(context.TODO(), se, now) + require.NoError(t, err) + require.Equal(t, "2024-03-10 00:31:00", tm.In(tz3).Format(time.DateTime)) + require.Same(t, now.Location(), tm.Location()) + + now, err = time.ParseInLocation(time.DateTime, "2024-03-10 04:01:00", tz3) + require.NoError(t, err) + tm, err = ttlTbl3.EvalExpireTime(context.TODO(), se, now) + require.NoError(t, err) + require.Equal(t, "2024-03-10 01:31:00", tm.In(tz3).Format(time.DateTime)) + require.Same(t, now.Location(), tm.Location()) + + now, err = time.ParseInLocation(time.DateTime, "2024-11-03 03:00:00", tz3) + require.NoError(t, err) + tm, err = ttlTbl3.EvalExpireTime(context.TODO(), se, now) + require.NoError(t, err) + require.Equal(t, "2024-11-03 01:30:00", tm.In(tz3).Format(time.DateTime)) + require.Same(t, now.Location(), tm.Location()) + // 2024-11-03 01:30:00 in America/Los_Angeles has two related time points: + // 2024-11-03 01:30:00 -0700 PDT + // 2024-11-03 
01:30:00 -0800 PST + // We must use the earlier one to avoid deleting some unexpected rows. + require.Equal(t, int64(5400), now.Unix()-tm.Unix()) + + // we should use global time zone to calculate the expired time + _, err = se.ExecuteSQL(context.TODO(), "SET @@global.time_zone = 'Asia/Shanghai'") + require.NoError(t, err) + now, err = time.ParseInLocation(time.DateTime, "1999-02-28 16:00:00", time.UTC) + require.NoError(t, err) + tm, err = ttlTbl2.EvalExpireTime(context.TODO(), se, now) + require.NoError(t, err) + require.Equal(t, "1998-12-01 00:00:00", tm.In(tz1).Format(time.DateTime)) + require.Same(t, time.UTC, tm.Location()) + + // time should be truncated to second to make the result simple + now, err = time.ParseInLocation("2006-01-02 15:04:05.000000", "2023-01-02 15:00:01.986542", time.UTC) + require.NoError(t, err) + tm, err = ttlTbl.EvalExpireTime(context.TODO(), se, now) + require.NoError(t, err) + require.Equal(t, "2023-01-01 15:00:01.000000", tm.Format("2006-01-02 15:04:05.000000")) + require.Same(t, time.UTC, tm.Location()) + + // test for string interval format + tk.MustExec("create table test.t4(a int, t datetime) ttl = `t` + interval '1:3' hour_minute") + tb4, err := do.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t4")) + require.NoError(t, err) + tblInfo4 := tb4.Meta() + ttlTbl4, err := cache.NewPhysicalTable(model.NewCIStr("test"), tblInfo4, model.NewCIStr("")) + require.NoError(t, err) + tm, err = ttlTbl4.EvalExpireTime(context.TODO(), se, time.Unix(0, 0).In(tz2)) + require.NoError(t, err) + require.Equal(t, "1969-12-31 22:57:00", tm.In(time.UTC).Format(time.DateTime)) + require.Same(t, tz2, tm.Location()) + + // session time zone should keep unchanged + tk.MustQuery("select @@time_zone").Check(testkit.Rows("Asia/Tokyo")) } diff --git a/pkg/ttl/session/BUILD.bazel b/pkg/ttl/session/BUILD.bazel index 34988e64f3..b7abb2e2b6 100644 --- a/pkg/ttl/session/BUILD.bazel +++ b/pkg/ttl/session/BUILD.bazel @@ -15,6 +15,7 @@ go_library( 
"//pkg/ttl/metrics", "//pkg/util/chunk", "//pkg/util/sqlexec", + "//pkg/util/timeutil", "@com_github_pingcap_errors//:errors", ], ) diff --git a/pkg/ttl/session/session.go b/pkg/ttl/session/session.go index 0d1248a6e9..c7e74a3b20 100644 --- a/pkg/ttl/session/session.go +++ b/pkg/ttl/session/session.go @@ -28,6 +28,7 @@ import ( "github.com/pingcap/tidb/pkg/ttl/metrics" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/pingcap/tidb/pkg/util/timeutil" ) // TxnMode represents using optimistic or pessimistic mode in the transaction @@ -51,6 +52,8 @@ type Session interface { RunInTxn(ctx context.Context, fn func() error, mode TxnMode) (err error) // ResetWithGlobalTimeZone resets the session time zone to global time zone ResetWithGlobalTimeZone(ctx context.Context) error + // GlobalTimeZone returns the global timezone. It is used to compute expire time for TTL + GlobalTimeZone(ctx context.Context) (*time.Location, error) // Close closes the session Close() // Now returns the current time in location specified by session var @@ -168,6 +171,15 @@ func (s *session) ResetWithGlobalTimeZone(ctx context.Context) error { return err } +// GlobalTimeZone returns the global timezone +func (s *session) GlobalTimeZone(ctx context.Context) (*time.Location, error) { + str, err := s.GetSessionVars().GetGlobalSystemVar(ctx, "time_zone") + if err != nil { + return nil, err + } + return timeutil.ParseTimeZone(str) +} + // Close closes the session func (s *session) Close() { if s.closeFn != nil { diff --git a/pkg/ttl/ttlworker/BUILD.bazel b/pkg/ttl/ttlworker/BUILD.bazel index 1415351100..5d36b6a909 100644 --- a/pkg/ttl/ttlworker/BUILD.bazel +++ b/pkg/ttl/ttlworker/BUILD.bazel @@ -36,6 +36,7 @@ go_library( "//pkg/types", "//pkg/util", "//pkg/util/chunk", + "//pkg/util/intest", "//pkg/util/logutil", "//pkg/util/sqlexec", "//pkg/util/timeutil", @@ -69,7 +70,7 @@ go_test( embed = [":ttlworker"], flaky = True, race = "on", - shard_count = 48, + 
shard_count = 50, deps = [ "//pkg/domain", "//pkg/infoschema", diff --git a/pkg/ttl/ttlworker/job.go b/pkg/ttl/ttlworker/job.go index 40e2ab7293..0f59a23244 100644 --- a/pkg/ttl/ttlworker/job.go +++ b/pkg/ttl/ttlworker/job.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/ttl/cache" "github.com/pingcap/tidb/pkg/ttl/session" + "github.com/pingcap/tidb/pkg/util/intest" "github.com/pingcap/tidb/pkg/util/logutil" "go.uber.org/zap" ) @@ -127,6 +128,7 @@ type ttlJob struct { // finish turns current job into last job, and update the error message and statistics summary func (job *ttlJob) finish(se session.Session, now time.Time, summary *TTLSummary) { + intest.Assert(se.GetSessionVars().Location().String() == now.Location().String()) // at this time, the job.ctx may have been canceled (to cancel this job) // even when it's canceled, we'll need to update the states, so use another context err := se.RunInTxn(context.TODO(), func() error { diff --git a/pkg/ttl/ttlworker/job_manager.go b/pkg/ttl/ttlworker/job_manager.go index be46a6817a..59f2bab360 100644 --- a/pkg/ttl/ttlworker/job_manager.go +++ b/pkg/ttl/ttlworker/job_manager.go @@ -33,6 +33,7 @@ import ( "github.com/pingcap/tidb/pkg/ttl/client" "github.com/pingcap/tidb/pkg/ttl/metrics" "github.com/pingcap/tidb/pkg/ttl/session" + "github.com/pingcap/tidb/pkg/util/intest" "github.com/pingcap/tidb/pkg/util/logutil" "github.com/pingcap/tidb/pkg/util/timeutil" clientv3 "go.etcd.io/etcd/client/v3" @@ -201,7 +202,7 @@ func (m *JobManager) jobLoop() error { case <-m.ctx.Done(): return nil case <-timerTicker: - m.onTimerTick(se, timerRT, timerSyncer, time.Now()) + m.onTimerTick(se, timerRT, timerSyncer, now) case jobReq := <-jobRequestCh: m.handleSubmitJobRequest(se, jobReq) case <-infoSchemaCacheUpdateTicker: @@ -327,7 +328,7 @@ func (m *JobManager) handleSubmitJobRequest(se session.Session, jobReq *SubmitTT return } - _, err := m.lockNewJob(m.ctx, se, tbl, time.Now(), jobReq.RequestID, false) 
+ _, err := m.lockNewJob(m.ctx, se, tbl, se.Now(), jobReq.RequestID, false) jobReq.RespCh <- err } @@ -359,7 +360,13 @@ func (m *JobManager) triggerTTLJob(requestID string, cmd *client.TriggerNewTTLJo return } - if !timeutil.WithinDayTimePeriod(variable.TTLJobScheduleWindowStartTime.Load(), variable.TTLJobScheduleWindowEndTime.Load(), time.Now()) { + tz, err := se.GlobalTimeZone(m.ctx) + if err != nil { + responseErr(err) + return + } + + if !timeutil.WithinDayTimePeriod(variable.TTLJobScheduleWindowStartTime.Load(), variable.TTLJobScheduleWindowEndTime.Load(), se.Now().In(tz)) { responseErr(errors.New("not in TTL job window")) return } @@ -558,6 +565,13 @@ j: } func (m *JobManager) rescheduleJobs(se session.Session, now time.Time) { + tz, err := se.GlobalTimeZone(m.ctx) + if err != nil { + terror.Log(err) + } else { + now = now.In(tz) + } + if !variable.EnableTTLJob.Load() || !timeutil.WithinDayTimePeriod(variable.TTLJobScheduleWindowStartTime.Load(), variable.TTLJobScheduleWindowEndTime.Load(), now) { if len(m.runningJobs) > 0 { for _, job := range m.runningJobs { @@ -705,6 +719,7 @@ func (m *JobManager) lockHBTimeoutJob(ctx context.Context, se session.Session, t jobID = tableStatus.CurrentJobID jobStart = tableStatus.CurrentJobStartTime expireTime = tableStatus.CurrentJobTTLExpire + intest.Assert(se.GetSessionVars().TimeZone.String() == now.Location().String()) sql, args := setTableStatusOwnerSQL(tableStatus.CurrentJobID, table.ID, jobStart, now, expireTime, m.id) if _, err = se.ExecuteSQL(ctx, sql, args...); err != nil { return errors.Wrapf(err, "execute sql: %s", sql) @@ -737,6 +752,9 @@ func (m *JobManager) lockNewJob(ctx context.Context, se session.Session, table * return err } + intest.Assert(se.GetSessionVars().TimeZone.String() == now.Location().String()) + intest.Assert(se.GetSessionVars().TimeZone.String() == expireTime.Location().String()) + sql, args := setTableStatusOwnerSQL(jobID, table.ID, now, now, expireTime, m.id) _, err = se.ExecuteSQL(ctx, 
sql, args...) if err != nil { @@ -859,6 +877,7 @@ func (m *JobManager) updateHeartBeat(ctx context.Context, se session.Session, no continue } + intest.Assert(se.GetSessionVars().TimeZone.String() == now.Location().String()) sql, args := updateHeartBeatSQL(job.tbl.ID, now, m.id) _, err := se.ExecuteSQL(ctx, sql, args...) if err != nil { @@ -1259,3 +1278,17 @@ func (a *managerJobAdapter) GetJob(ctx context.Context, tableID, physicalID int6 return &jobTrace, nil } + +func (a *managerJobAdapter) Now() (time.Time, error) { + se, err := getSession(a.sessPool) + if err != nil { + return time.Time{}, err + } + + tz, err := se.GlobalTimeZone(context.TODO()) + if err != nil { + return time.Time{}, err + } + + return se.Now().In(tz), nil +} diff --git a/pkg/ttl/ttlworker/job_manager_integration_test.go b/pkg/ttl/ttlworker/job_manager_integration_test.go index 59d3fd90bf..06d21fe2e4 100644 --- a/pkg/ttl/ttlworker/job_manager_integration_test.go +++ b/pkg/ttl/ttlworker/job_manager_integration_test.go @@ -25,6 +25,7 @@ import ( "time" "github.com/google/uuid" + "github.com/ngaut/pools" "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/domain" @@ -63,6 +64,45 @@ func sessionFactory(t *testing.T, store kv.Storage) func() session.Session { } } +func TestGetSession(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@time_zone = 'Asia/Shanghai'") + tk.MustExec("set @@global.time_zone= 'Europe/Berlin'") + tk.MustExec("set @@tidb_retry_limit=1") + tk.MustExec("set @@tidb_enable_1pc=0") + tk.MustExec("set @@tidb_enable_async_commit=0") + var getCnt atomic.Int32 + + pool := pools.NewResourcePool(func() (pools.Resource, error) { + if getCnt.CompareAndSwap(0, 1) { + return tk.Session(), nil + } + require.FailNow(t, "get session more than once") + return nil, nil + }, 1, 1, 0) + defer pool.Close() + + se, err := ttlworker.GetSessionForTest(pool) + require.NoError(t, err) + defer se.Close() + + 
// global time zone should not change + tk.MustQuery("select @@global.time_zone").Check(testkit.Rows("Europe/Berlin")) + tz, err := se.GlobalTimeZone(context.TODO()) + require.NoError(t, err) + require.Equal(t, "Europe/Berlin", tz.String()) + + // session variables should be set + tk.MustQuery("select @@time_zone, @@tidb_retry_limit, @@tidb_enable_1pc, @@tidb_enable_async_commit"). + Check(testkit.Rows("UTC 0 1 1")) + + // all session variables should be restored after close + se.Close() + tk.MustQuery("select @@time_zone, @@tidb_retry_limit, @@tidb_enable_1pc, @@tidb_enable_async_commit"). + Check(testkit.Rows("Asia/Shanghai 1 0 0")) +} + func TestParallelLockNewJob(t *testing.T) { store, dom := testkit.CreateMockStoreAndDomain(t) waitAndStopTTLManager(t, dom) @@ -75,14 +115,14 @@ func TestParallelLockNewJob(t *testing.T) { m.InfoSchemaCache().Tables[testTable.ID] = testTable se := sessionFactory() - job, err := m.LockJob(context.Background(), se, testTable, time.Now(), uuid.NewString(), false) + job, err := m.LockJob(context.Background(), se, testTable, se.Now(), uuid.NewString(), false) require.NoError(t, err) - job.Finish(se, time.Now(), &ttlworker.TTLSummary{}) + job.Finish(se, se.Now(), &ttlworker.TTLSummary{}) // lock one table in parallel, only one of them should lock successfully testTimes := 100 concurrency := 5 - now := time.Now() + now := se.Now() for i := 0; i < testTimes; i++ { successCounter := atomic.NewUint64(0) successJob := &ttlworker.TTLJob{} @@ -111,7 +151,7 @@ func TestParallelLockNewJob(t *testing.T) { wg.Wait() require.Equal(t, uint64(1), successCounter.Load()) - successJob.Finish(se, time.Now(), &ttlworker.TTLSummary{}) + successJob.Finish(se, se.Now(), &ttlworker.TTLSummary{}) } } @@ -131,7 +171,7 @@ func TestFinishJob(t *testing.T) { m := ttlworker.NewJobManager("test-id", nil, store, nil, nil) m.InfoSchemaCache().Tables[testTable.ID] = testTable se := sessionFactory() - startTime := time.Now() + startTime := se.Now() job, err := 
m.LockJob(context.Background(), se, testTable, startTime, uuid.NewString(), false) require.NoError(t, err) @@ -156,7 +196,7 @@ func TestFinishJob(t *testing.T) { summary.SummaryText = string(summaryBytes) require.NoError(t, err) - endTime := time.Now() + endTime := se.Now() job.Finish(se, endTime, summary) tk.MustQuery("select table_id, last_job_summary from mysql.tidb_ttl_table_status").Check(testkit.Rows("2 " + summary.SummaryText)) tk.MustQuery("select * from mysql.tidb_ttl_task").Check(testkit.Rows()) @@ -609,7 +649,6 @@ func TestJobTimeout(t *testing.T) { waitAndStopTTLManager(t, dom) - now := time.Now() tk.MustExec("create table test.t (id int, created_at datetime) ttl = `created_at` + interval 1 minute ttl_job_interval = '1m'") table, err := dom.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t")) tableID := table.Meta().ID @@ -617,6 +656,7 @@ func TestJobTimeout(t *testing.T) { ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnTTL) se := sessionFactory() + now := se.Now() m := ttlworker.NewJobManager("manager-1", nil, store, nil, func() bool { return true }) @@ -1246,3 +1286,20 @@ func TestManagerJobAdapterGetJob(t *testing.T) { tk.MustExec("delete from mysql.tidb_ttl_job_history") } } + +func TestManagerJobAdapterNow(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + adapter := ttlworker.NewManagerJobAdapter(store, dom.SysSessionPool(), nil) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("set @@global.time_zone ='Europe/Berlin'") + tk.MustExec("set @@time_zone='Asia/Shanghai'") + + now, err := adapter.Now() + require.NoError(t, err) + localNow := time.Now() + + require.Equal(t, "Europe/Berlin", now.Location().String()) + require.InDelta(t, now.Unix(), localNow.Unix(), 10) +} diff --git a/pkg/ttl/ttlworker/job_manager_test.go b/pkg/ttl/ttlworker/job_manager_test.go index 421fdeb653..4c9aa6c0da 100644 --- a/pkg/ttl/ttlworker/job_manager_test.go +++ 
b/pkg/ttl/ttlworker/job_manager_test.go @@ -21,6 +21,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" timerapi "github.com/pingcap/tidb/pkg/timer/api" @@ -141,6 +142,11 @@ var updateStatusSQL = "SELECT LOW_PRIORITY table_id,parent_table_id,table_statis // TTLJob exports the ttlJob for test type TTLJob = ttlJob +// GetSessionForTest is used for test +func GetSessionForTest(pool sessionPool) (session.Session, error) { + return getSession(pool) +} + // LockJob is an exported version of lockNewJob for test func (m *JobManager) LockJob(ctx context.Context, se session.Session, table *cache.PhysicalTable, now time.Time, createJobID string, checkInterval bool) (*TTLJob, error) { if createJobID == "" { @@ -227,17 +233,17 @@ func TestReadyForLockHBTimeoutJobTables(t *testing.T) { // table only in the table status cache will not be scheduled {"proper subset", []*cache.PhysicalTable{}, []*cache.TableStatus{{TableID: tbl.ID, ParentTableID: tbl.ID}}, false}, // table whose current job owner id is not empty, and heart beat time is long enough will not be scheduled - {"current job not empty", []*cache.PhysicalTable{tbl}, []*cache.TableStatus{{TableID: tbl.ID, ParentTableID: tbl.ID, CurrentJobID: "job1", CurrentJobOwnerID: "test-another-id", CurrentJobOwnerHBTime: time.Now()}}, false}, + {"current job not empty", []*cache.PhysicalTable{tbl}, []*cache.TableStatus{{TableID: tbl.ID, ParentTableID: tbl.ID, CurrentJobID: "job1", CurrentJobOwnerID: "test-another-id", CurrentJobOwnerHBTime: se.Now()}}, false}, // table whose current job owner id is not empty, but heart beat time is expired will be scheduled - {"hb time expired", []*cache.PhysicalTable{tbl}, []*cache.TableStatus{{TableID: tbl.ID, ParentTableID: tbl.ID, CurrentJobID: "job1", CurrentJobOwnerID: "test-another-id", CurrentJobOwnerHBTime: time.Now().Add(-time.Hour)}}, true}, + {"hb time expired", 
[]*cache.PhysicalTable{tbl}, []*cache.TableStatus{{TableID: tbl.ID, ParentTableID: tbl.ID, CurrentJobID: "job1", CurrentJobOwnerID: "test-another-id", CurrentJobOwnerHBTime: se.Now().Add(-time.Hour)}}, true}, // if the last start time is too near, it will not be scheduled because no job running - {"last start time too near", []*cache.PhysicalTable{tbl}, []*cache.TableStatus{{TableID: tbl.ID, ParentTableID: tbl.ID, LastJobStartTime: time.Now()}}, false}, + {"last start time too near", []*cache.PhysicalTable{tbl}, []*cache.TableStatus{{TableID: tbl.ID, ParentTableID: tbl.ID, LastJobStartTime: se.Now()}}, false}, // if the last start time is expired, it will not be scheduled because no job running - {"last start time expired", []*cache.PhysicalTable{tbl}, []*cache.TableStatus{{TableID: tbl.ID, ParentTableID: tbl.ID, LastJobStartTime: time.Now().Add(-time.Hour * 2)}}, false}, + {"last start time expired", []*cache.PhysicalTable{tbl}, []*cache.TableStatus{{TableID: tbl.ID, ParentTableID: tbl.ID, LastJobStartTime: se.Now().Add(-time.Hour * 2)}}, false}, // if the interval is 24h, and the last start time is near, it will not be scheduled because no job running - {"last start time too near for 24h", []*cache.PhysicalTable{tblWithDailyInterval}, []*cache.TableStatus{{TableID: tblWithDailyInterval.ID, ParentTableID: tblWithDailyInterval.ID, LastJobStartTime: time.Now().Add(-time.Hour * 2)}}, false}, + {"last start time too near for 24h", []*cache.PhysicalTable{tblWithDailyInterval}, []*cache.TableStatus{{TableID: tblWithDailyInterval.ID, ParentTableID: tblWithDailyInterval.ID, LastJobStartTime: se.Now().Add(-time.Hour * 2)}}, false}, // if the interval is 24h, and the last start time is far enough, it will not be scheduled because no job running - {"last start time far enough for 24h", []*cache.PhysicalTable{tblWithDailyInterval}, []*cache.TableStatus{{TableID: tblWithDailyInterval.ID, ParentTableID: tblWithDailyInterval.ID, LastJobStartTime: time.Now().Add(-time.Hour * 
25)}}, false}, + {"last start time far enough for 24h", []*cache.PhysicalTable{tblWithDailyInterval}, []*cache.TableStatus{{TableID: tblWithDailyInterval.ID, ParentTableID: tblWithDailyInterval.ID, LastJobStartTime: se.Now().Add(-time.Hour * 25)}}, false}, } for _, c := range cases { @@ -366,7 +372,7 @@ func TestLockTable(t *testing.T) { oldJobExpireTime := now.Add(-time.Hour) oldJobStartTime := now.Add(-30 * time.Minute) - testPhysicalTable := &cache.PhysicalTable{ID: 1, Schema: model.NewCIStr("test"), TableInfo: &model.TableInfo{ID: 1, Name: model.NewCIStr("t1"), TTLInfo: &model.TTLInfo{ColumnName: model.NewCIStr("test"), IntervalExprStr: "5 Year", JobInterval: "1h"}}} + testPhysicalTable := &cache.PhysicalTable{ID: 1, Schema: model.NewCIStr("test"), TableInfo: &model.TableInfo{ID: 1, Name: model.NewCIStr("t1"), TTLInfo: &model.TTLInfo{ColumnName: model.NewCIStr("test"), IntervalExprStr: "1", IntervalTimeUnit: int(ast.TimeUnitMinute), JobInterval: "1h"}}} type executeInfo struct { sql string @@ -600,8 +606,8 @@ func TestLockTable(t *testing.T) { sqlCounter += 1 return } - se.evalExpire = newJobExpireTime + m.ctx = cache.SetMockExpireTime(context.Background(), newJobExpireTime) var job *ttlJob if c.isCreate { job, err = m.lockNewJob(context.Background(), se, c.table, now, "new-job-id", c.checkInterval) diff --git a/pkg/ttl/ttlworker/scan.go b/pkg/ttl/ttlworker/scan.go index 71b696aefa..2a28b2dc13 100644 --- a/pkg/ttl/ttlworker/scan.go +++ b/pkg/ttl/ttlworker/scan.go @@ -111,7 +111,7 @@ func (t *ttlScanTask) doScan(ctx context.Context, delCh chan<- *ttlDeleteTask, s } defer rawSess.Close() - safeExpire, err := t.tbl.EvalExpireTime(taskCtx, rawSess, time.Now()) + safeExpire, err := t.tbl.EvalExpireTime(taskCtx, rawSess, rawSess.Now()) if err != nil { return t.result(err) } @@ -127,7 +127,7 @@ func (t *ttlScanTask) doScan(ctx context.Context, delCh chan<- *ttlDeleteTask, s // because `ExecuteSQLWithCheck` only do checks when the table meta used by task is different 
with the latest one. // In this case, some rows will be deleted unexpectedly. if t.ExpireTime.After(safeExpire) { - return t.result(errors.Errorf("current expire time is after safe expire time. (%d > %d)", t.ExpireTime.UnixMilli(), safeExpire.UnixMilli())) + return t.result(errors.Errorf("current expire time is after safe expire time. (%d > %d)", t.ExpireTime.Unix(), safeExpire.Unix())) } origConcurrency := rawSess.GetSessionVars().DistSQLScanConcurrency() diff --git a/pkg/ttl/ttlworker/scan_test.go b/pkg/ttl/ttlworker/scan_test.go index 5177e175d4..137e72fe9a 100644 --- a/pkg/ttl/ttlworker/scan_test.go +++ b/pkg/ttl/ttlworker/scan_test.go @@ -135,7 +135,7 @@ func TestScanWorkerSchedule(t *testing.T) { defer w.stopWithWait() task := &ttlScanTask{ - ctx: context.Background(), + ctx: cache.SetMockExpireTime(context.Background(), time.Now()), tbl: tbl, TTLTask: &cache.TTLTask{ ExpireTime: time.UnixMilli(0), @@ -184,7 +184,7 @@ func TestScanWorkerScheduleWithFailedTask(t *testing.T) { defer w.stopWithWait() task := &ttlScanTask{ - ctx: context.Background(), + ctx: cache.SetMockExpireTime(context.Background(), time.Now()), tbl: tbl, TTLTask: &cache.TTLTask{ ExpireTime: time.UnixMilli(0), @@ -392,6 +392,7 @@ func (t *mockScanTask) execSQL(_ context.Context, sql string, _ ...any) ([]chunk func TestScanTaskDoScan(t *testing.T) { task := newMockScanTask(t, 3) + task.ctx = cache.SetMockExpireTime(task.ctx, time.Now()) task.sqlRetry[1] = scanTaskExecuteSQLMaxRetry task.runDoScanForTest(3, "") @@ -412,13 +413,13 @@ func TestScanTaskDoScan(t *testing.T) { func TestScanTaskCheck(t *testing.T) { tbl := newMockTTLTbl(t, "t1") pool := newMockSessionPool(t, tbl) - pool.se.evalExpire = time.UnixMilli(100) pool.se.rows = newMockRows(t, types.NewFieldType(mysql.TypeInt24)).Append(12).Rows() + ctx := cache.SetMockExpireTime(context.Background(), time.Unix(100, 0)) task := &ttlScanTask{ - ctx: context.Background(), + ctx: ctx, TTLTask: &cache.TTLTask{ - ExpireTime: 
time.UnixMilli(101).Add(time.Minute), + ExpireTime: time.Unix(101, 0).Add(time.Minute), }, tbl: tbl, statistics: &ttlStatistics{}, @@ -427,14 +428,14 @@ func TestScanTaskCheck(t *testing.T) { ch := make(chan *ttlDeleteTask, 1) result := task.doScan(context.Background(), ch, pool) require.Equal(t, task, result.task) - require.EqualError(t, result.err, "current expire time is after safe expire time. (60101 > 60100)") + require.EqualError(t, result.err, "current expire time is after safe expire time. (161 > 160)") require.Equal(t, 0, len(ch)) require.Equal(t, "Total Rows: 0, Success Rows: 0, Error Rows: 0", task.statistics.String()) task = &ttlScanTask{ - ctx: context.Background(), + ctx: ctx, TTLTask: &cache.TTLTask{ - ExpireTime: time.UnixMilli(100).Add(time.Minute), + ExpireTime: time.Unix(100, 0).Add(time.Minute), }, tbl: tbl, statistics: &ttlStatistics{}, diff --git a/pkg/ttl/ttlworker/session.go b/pkg/ttl/ttlworker/session.go index b9085f4a1d..f5b3574998 100644 --- a/pkg/ttl/ttlworker/session.go +++ b/pkg/ttl/ttlworker/session.go @@ -28,6 +28,7 @@ import ( "github.com/pingcap/tidb/pkg/ttl/metrics" "github.com/pingcap/tidb/pkg/ttl/session" "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/intest" "github.com/pingcap/tidb/pkg/util/logutil" "github.com/pingcap/tidb/pkg/util/sqlexec" "go.uber.org/zap" @@ -85,19 +86,30 @@ func getSession(pool sessionPool) (session.Session, error) { originalRetryLimit := sctx.GetSessionVars().RetryLimit originalEnable1PC := sctx.GetSessionVars().Enable1PC originalEnableAsyncCommit := sctx.GetSessionVars().EnableAsyncCommit + originalTimeZone, restoreTimeZone := "", false + se := session.NewSession(sctx, exec, func(se session.Session) { _, err = se.ExecuteSQL(context.Background(), fmt.Sprintf("set tidb_retry_limit=%d", originalRetryLimit)) if err != nil { + intest.AssertNoError(err) logutil.BgLogger().Error("fail to reset tidb_retry_limit", zap.Int64("originalRetryLimit", originalRetryLimit), zap.Error(err)) 
} if !originalEnable1PC { _, err = se.ExecuteSQL(context.Background(), "set tidb_enable_1pc=OFF") + intest.AssertNoError(err) terror.Log(err) } if !originalEnableAsyncCommit { _, err = se.ExecuteSQL(context.Background(), "set tidb_enable_async_commit=OFF") + intest.AssertNoError(err) + terror.Log(err) + } + + if restoreTimeZone { + _, err = se.ExecuteSQL(context.Background(), "set @@time_zone=%?", originalTimeZone) + intest.AssertNoError(err) terror.Log(err) } @@ -135,6 +147,26 @@ func getSession(pool sessionPool) (session.Session, error) { return nil, err } + // set the time zone to UTC + rows, err := se.ExecuteSQL(context.Background(), "select @@time_zone") + if err != nil { + se.Close() + return nil, err + } + + if len(rows) == 0 || rows[0].Len() == 0 { + se.Close() + return nil, errors.New("failed to get time_zone variable") + } + originalTimeZone = rows[0].GetString(0) + + _, err = se.ExecuteSQL(context.Background(), "set @@time_zone='UTC'") + if err != nil { + se.Close() + return nil, err + } + restoreTimeZone = true + return se, nil } diff --git a/pkg/ttl/ttlworker/session_test.go b/pkg/ttl/ttlworker/session_test.go index d777efc7bb..1d531855e1 100644 --- a/pkg/ttl/ttlworker/session_test.go +++ b/pkg/ttl/ttlworker/session_test.go @@ -98,6 +98,10 @@ func (r *mockRows) Append(row ...any) *mockRows { val, ok := row[i].(int) require.True(r.t, ok) r.AppendInt64(i, int64(val)) + case mysql.TypeString: + val, ok := row[i].(string) + require.True(r.t, ok) + r.AppendString(i, val) default: require.FailNow(r.t, "unsupported tp %v", tp) } @@ -141,7 +145,6 @@ type mockSession struct { executeSQL func(ctx context.Context, sql string, args ...any) ([]chunk.Row, error) rows []chunk.Row execErr error - evalExpire time.Time resetTimeZoneCalls int closed bool commitErr error @@ -157,7 +160,6 @@ func newMockSession(t *testing.T, tbl ...*cache.PhysicalTable) *mockSession { return &mockSession{ t: t, sessionInfoSchema: newMockInfoSchema(tbls...), - evalExpire: time.Now(), 
sessionVars: sessVars, } } @@ -179,7 +181,11 @@ func (s *mockSession) GetSessionVars() *variable.SessionVars { func (s *mockSession) ExecuteSQL(ctx context.Context, sql string, args ...any) ([]chunk.Row, error) { require.False(s.t, s.closed) if strings.HasPrefix(strings.ToUpper(sql), "SELECT FROM_UNIXTIME") { - return newMockRows(s.t, types.NewFieldType(mysql.TypeTimestamp)).Append(s.evalExpire.In(s.GetSessionVars().TimeZone)).Rows(), nil + panic("not supported") + } + + if strings.ToUpper(sql) == "SELECT @@TIME_ZONE" { + panic("not supported") } if strings.HasPrefix(strings.ToUpper(sql), "SET ") { @@ -206,6 +212,11 @@ func (s *mockSession) ResetWithGlobalTimeZone(_ context.Context) (err error) { return nil } +// GlobalTimeZone returns the global timezone +func (s *mockSession) GlobalTimeZone(_ context.Context) (*time.Location, error) { + return time.Local, nil +} + func (s *mockSession) Close() { s.closed = true } @@ -263,7 +274,7 @@ func TestValidateTTLWork(t *testing.T) { s := newMockSession(t, tbl) s.execErr = errors.New("mockErr") - s.evalExpire = time.UnixMilli(0).In(time.UTC) + ctx = cache.SetMockExpireTime(ctx, time.UnixMilli(0).In(time.UTC)) // test table dropped s.sessionInfoSchema = newMockInfoSchema() @@ -311,13 +322,13 @@ func TestValidateTTLWork(t *testing.T) { tbl2 = tbl.TableInfo.Clone() tbl2.TTLInfo.IntervalExprStr = "10" s.sessionInfoSchema = newMockInfoSchema(tbl2) - s.evalExpire = time.UnixMilli(-1) + ctx = cache.SetMockExpireTime(ctx, time.UnixMilli(-1)) err = validateTTLWork(ctx, s, tbl, expire) require.EqualError(t, err, "expire interval changed") tbl2 = tbl.TableInfo.Clone() tbl2.TTLInfo.IntervalTimeUnit = int(ast.TimeUnitDay) - s.evalExpire = time.UnixMilli(-1) + ctx = cache.SetMockExpireTime(ctx, time.UnixMilli(-1)) s.sessionInfoSchema = newMockInfoSchema(tbl2) err = validateTTLWork(ctx, s, tbl, expire) require.EqualError(t, err, "expire interval changed") @@ -328,7 +339,7 @@ func TestValidateTTLWork(t *testing.T) { tbl2.Columns[0].ID += 
10 tbl2.Columns[0].FieldType = *types.NewFieldType(mysql.TypeDate) tbl2.TTLInfo.IntervalExprStr = "100" - s.evalExpire = time.UnixMilli(1000) + ctx = cache.SetMockExpireTime(ctx, time.UnixMilli(1000)) s.sessionInfoSchema = newMockInfoSchema(tbl2) err = validateTTLWork(ctx, s, tbl, expire) require.NoError(t, err) diff --git a/pkg/ttl/ttlworker/task_manager.go b/pkg/ttl/ttlworker/task_manager.go index af64d3ce15..375a19c9e7 100644 --- a/pkg/ttl/ttlworker/task_manager.go +++ b/pkg/ttl/ttlworker/task_manager.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/pkg/ttl/cache" "github.com/pingcap/tidb/pkg/ttl/metrics" "github.com/pingcap/tidb/pkg/ttl/session" + "github.com/pingcap/tidb/pkg/util/intest" "github.com/pingcap/tidb/pkg/util/logutil" "github.com/tikv/client-go/v2/tikv" "github.com/tikv/client-go/v2/tikvrpc" @@ -321,6 +322,7 @@ loop: } func (m *taskManager) peekWaitingScanTasks(se session.Session, now time.Time) ([]*cache.TTLTask, error) { + intest.Assert(se.GetSessionVars().Location().String() == now.Location().String()) sql, args := cache.PeekWaitingTTLTask(now.Add(-getTaskManagerHeartBeatExpireInterval())) rows, err := se.ExecuteSQL(m.ctx, sql, args...) if err != nil { @@ -372,6 +374,7 @@ func (m *taskManager) lockScanTask(se session.Session, task *cache.TTLTask, now return errors.WithStack(errTooManyRunningTasks) } + intest.Assert(se.GetSessionVars().Location().String() == now.Location().String()) sql, args := setTTLTaskOwnerSQL(task.JobID, task.ScanID, m.id, now) _, err = se.ExecuteSQL(ctx, sql, args...) 
if err != nil { @@ -440,6 +443,7 @@ func (m *taskManager) updateHeartBeat(ctx context.Context, se session.Session, n state.ScanTaskErr = task.result.err.Error() } + intest.Assert(se.GetSessionVars().Location().String() == now.Location().String()) sql, args, err := updateTTLTaskHeartBeatSQL(task.JobID, task.ScanID, now, state) if err != nil { return err @@ -482,6 +486,7 @@ func (m *taskManager) reportTaskFinished(se session.Session, now time.Time, task state.ScanTaskErr = task.result.err.Error() } + intest.Assert(se.GetSessionVars().Location().String() == now.Location().String()) sql, args, err := setTTLTaskFinishedSQL(task.JobID, task.ScanID, state, now) if err != nil { return err diff --git a/pkg/ttl/ttlworker/task_manager_integration_test.go b/pkg/ttl/ttlworker/task_manager_integration_test.go index dfa839ee13..02ed0f3c11 100644 --- a/pkg/ttl/ttlworker/task_manager_integration_test.go +++ b/pkg/ttl/ttlworker/task_manager_integration_test.go @@ -49,7 +49,7 @@ func TestParallelLockNewTask(t *testing.T) { sessionFactory := sessionFactory(t, store) se := sessionFactory() - now := time.Now() + now := se.Now() isc := cache.NewInfoSchemaCache(time.Minute) require.NoError(t, isc.Update(se)) @@ -129,7 +129,6 @@ func TestParallelSchedule(t *testing.T) { } isc := cache.NewInfoSchemaCache(time.Second) require.NoError(t, isc.Update(sessionFactory())) - now := time.Now() scheduleWg := sync.WaitGroup{} finishTasks := make([]func(), 0, 4) for i := 0; i < 4; i++ { @@ -146,7 +145,7 @@ func TestParallelSchedule(t *testing.T) { scheduleWg.Add(1) go func() { se := sessionFactory() - m.RescheduleTasks(se, now) + m.RescheduleTasks(se, se.Now()) scheduleWg.Done() }() finishTasks = append(finishTasks, func() { @@ -154,7 +153,7 @@ func TestParallelSchedule(t *testing.T) { for _, task := range m.GetRunningTasks() { require.Nil(t, task.Context().Err(), fmt.Sprintf("%s %d", managerID, task.ScanID)) task.SetResult(nil) - m.CheckFinishedTask(se, time.Now()) + m.CheckFinishedTask(se, se.Now()) 
require.NotNil(t, task.Context().Err(), fmt.Sprintf("%s %d", managerID, task.ScanID)) } }) @@ -188,14 +187,15 @@ func TestTaskScheduleExpireHeartBeat(t *testing.T) { // update the infoschema cache isc := cache.NewInfoSchemaCache(time.Second) require.NoError(t, isc.Update(sessionFactory())) - now := time.Now() // schedule in a task manager scanWorker := ttlworker.NewMockScanWorker(t) scanWorker.Start() m := ttlworker.NewTaskManager(context.Background(), nil, isc, "task-manager-1", store) m.SetScanWorkers4Test([]ttlworker.Worker{scanWorker}) - m.RescheduleTasks(sessionFactory(), now) + se := sessionFactory() + now := se.Now() + m.RescheduleTasks(se, now) tk.MustQuery("select status,owner_id from mysql.tidb_ttl_task").Check(testkit.Rows("running task-manager-1")) // another task manager should fetch this task after heartbeat expire @@ -235,13 +235,14 @@ func TestTaskMetrics(t *testing.T) { // update the infoschema cache isc := cache.NewInfoSchemaCache(time.Second) require.NoError(t, isc.Update(sessionFactory())) - now := time.Now() // schedule in a task manager scanWorker := ttlworker.NewMockScanWorker(t) scanWorker.Start() m := ttlworker.NewTaskManager(context.Background(), nil, isc, "task-manager-1", store) m.SetScanWorkers4Test([]ttlworker.Worker{scanWorker}) + se := sessionFactory() + now := se.Now() m.RescheduleTasks(sessionFactory(), now) tk.MustQuery("select status,owner_id from mysql.tidb_ttl_task").Check(testkit.Rows("running task-manager-1")) @@ -262,9 +263,10 @@ func TestRescheduleWithError(t *testing.T) { sql := fmt.Sprintf("insert into mysql.tidb_ttl_task(job_id,table_id,scan_id,expire_time,created_time) values ('test-job', %d, %d, NOW(), NOW())", 613, 1) tk.MustExec(sql) + se := sessionFactory() + now := se.Now() isc := cache.NewInfoSchemaCache(time.Second) - require.NoError(t, isc.Update(sessionFactory())) - now := time.Now() + require.NoError(t, isc.Update(se)) // schedule in a task manager scanWorker := ttlworker.NewMockScanWorker(t) @@ -304,7 +306,6 
@@ func TestTTLRunningTasksLimitation(t *testing.T) { } isc := cache.NewInfoSchemaCache(time.Second) require.NoError(t, isc.Update(sessionFactory())) - now := time.Now() scheduleWg := sync.WaitGroup{} for i := 0; i < 16; i++ { workers := []ttlworker.Worker{} @@ -320,7 +321,7 @@ func TestTTLRunningTasksLimitation(t *testing.T) { scheduleWg.Add(1) go func() { se := sessionFactory() - m.RescheduleTasks(se, now) + m.RescheduleTasks(se, se.Now()) scheduleWg.Done() }() } diff --git a/pkg/ttl/ttlworker/timer.go b/pkg/ttl/ttlworker/timer.go index e4ca260b47..751b2bb10c 100644 --- a/pkg/ttl/ttlworker/timer.go +++ b/pkg/ttl/ttlworker/timer.go @@ -50,6 +50,8 @@ type TTLJobTrace struct { // TTLJobAdapter is used to submit TTL job and trace job status type TTLJobAdapter interface { + // Now returns the current time with system timezone. + Now() (time.Time, error) // CanSubmitJob returns whether a new job can be created for the specified table CanSubmitJob(tableID, physicalID int64) bool // SubmitJob submits a new job @@ -64,7 +66,6 @@ type ttlTimerHook struct { ctx context.Context cancel func() wg sync.WaitGroup - nowFunc func() time.Time checkTTLJobInterval time.Duration // waitJobLoopCounter is only used for test waitJobLoopCounter int64 @@ -77,7 +78,6 @@ func newTTLTimerHook(adapter TTLJobAdapter, cli timerapi.TimerClient) *ttlTimerH cli: cli, ctx: ctx, cancel: cancel, - nowFunc: time.Now, checkTTLJobInterval: defaultCheckTTLJobInterval, } } @@ -95,8 +95,13 @@ func (t *ttlTimerHook) OnPreSchedEvent(_ context.Context, event timerapi.TimerSh return } + now, err := t.adapter.Now() + if err != nil { + return r, err + } + windowStart, windowEnd := variable.TTLJobScheduleWindowStartTime.Load(), variable.TTLJobScheduleWindowEndTime.Load() - if !timeutil.WithinDayTimePeriod(windowStart, windowEnd, t.nowFunc()) { + if !timeutil.WithinDayTimePeriod(windowStart, windowEnd, now) { r.Delay = time.Minute return } @@ -154,7 +159,12 @@ func (t *ttlTimerHook) OnSchedEvent(ctx 
context.Context, event timerapi.TimerShe logger.Warn("cancel current TTL timer event because table's ttl is not enabled") } - if t.nowFunc().Sub(timer.EventStart) > 10*time.Minute { + now, err := t.adapter.Now() + if err != nil { + return err + } + + if now.Sub(timer.EventStart) > 10*time.Minute { cancel = true logger.Warn("cancel current TTL timer event because job not submitted for a long time") } diff --git a/pkg/ttl/ttlworker/timer_test.go b/pkg/ttl/ttlworker/timer_test.go index 11b05c4b53..467e07bc43 100644 --- a/pkg/ttl/ttlworker/timer_test.go +++ b/pkg/ttl/ttlworker/timer_test.go @@ -55,6 +55,14 @@ func (a *mockJobAdapter) GetJob(ctx context.Context, tableID, physicalID int64, return job, args.Error(1) } +func (a *mockJobAdapter) Now() (now time.Time, _ error) { + args := a.Called() + if obj := args.Get(0); obj != nil { + now = obj.(time.Time) + } + return now, nil +} + type mockTimerCli struct { mock.Mock timerapi.TimerClient @@ -183,6 +191,7 @@ func TestTTLTimerHookPrepare(t *testing.T) { // normal adapter.On("CanSubmitJob", data.TableID, data.PhysicalID).Return(true).Once() + adapter.On("Now").Return(time.Now()).Once() r, err := hook.OnPreSchedEvent(context.TODO(), &mockTimerSchedEvent{eventID: "event1", timer: timer}) require.NoError(t, err) require.Equal(t, timerapi.PreSchedEventResult{}, r) @@ -197,9 +206,7 @@ func TestTTLTimerHookPrepare(t *testing.T) { // not in window now := time.Date(2023, 1, 1, 15, 10, 0, 0, time.UTC) - hook.nowFunc = func() time.Time { - return now - } + adapter.On("Now").Return(now, nil).Once() clearTTLWindowAndEnable() variable.TTLJobScheduleWindowStartTime.Store(time.Date(0, 0, 0, 15, 11, 0, 0, time.UTC)) r, err = hook.OnPreSchedEvent(context.TODO(), &mockTimerSchedEvent{eventID: "event1", timer: timer}) @@ -208,6 +215,7 @@ func TestTTLTimerHookPrepare(t *testing.T) { adapter.AssertExpectations(t) clearTTLWindowAndEnable() + adapter.On("Now").Return(now, nil).Once() variable.TTLJobScheduleWindowEndTime.Store(time.Date(0, 0, 0, 
15, 9, 0, 0, time.UTC)) r, err = hook.OnPreSchedEvent(context.TODO(), &mockTimerSchedEvent{eventID: "event1", timer: timer}) require.NoError(t, err) @@ -216,6 +224,7 @@ func TestTTLTimerHookPrepare(t *testing.T) { // in window clearTTLWindowAndEnable() + adapter.On("Now").Return(now, nil).Once() adapter.On("CanSubmitJob", data.TableID, data.PhysicalID).Return(true).Once() variable.TTLJobScheduleWindowStartTime.Store(time.Date(0, 0, 0, 15, 9, 0, 0, time.UTC)) variable.TTLJobScheduleWindowEndTime.Store(time.Date(0, 0, 0, 15, 11, 0, 0, time.UTC)) @@ -226,6 +235,7 @@ func TestTTLTimerHookPrepare(t *testing.T) { // CanSubmitJob returns false clearTTLWindowAndEnable() + adapter.On("Now").Return(now, nil).Once() adapter.On("CanSubmitJob", data.TableID, data.PhysicalID).Return(false).Once() r, err = hook.OnPreSchedEvent(context.TODO(), &mockTimerSchedEvent{eventID: "event1", timer: timer}) require.NoError(t, err) @@ -267,6 +277,7 @@ func TestTTLTimerHookOnEvent(t *testing.T) { adapter.On("SubmitJob", ctx, data.TableID, data.PhysicalID, timer.EventID, timer.EventStart). Return(nil, errors.New("mockSubmitErr")). Once() + adapter.On("Now").Return(time.Now()).Once() err = hook.OnSchedEvent(ctx, &mockTimerSchedEvent{eventID: timer.EventID, timer: timer}) require.EqualError(t, err, "mockSubmitErr") adapter.AssertExpectations(t) @@ -305,6 +316,7 @@ func TestTTLTimerHookOnEvent(t *testing.T) { adapter.On("GetJob", hook.ctx, data.TableID, data.PhysicalID, timer.EventID). Return(&TTLJobTrace{RequestID: timer.EventID, Finished: true, Summary: summary.LastJobSummary}, nil). Once() + adapter.On("Now").Return(time.Now()).Once() err = hook.OnSchedEvent(ctx, &mockTimerSchedEvent{eventID: timer.EventID, timer: timer}) require.NoError(t, err) require.Equal(t, int64(1), hook.waitJobLoopCounter) @@ -347,6 +359,7 @@ func TestTTLTimerHookOnEvent(t *testing.T) { adapter.On("CanSubmitJob", data.TableID, data.PhysicalID). Return(false). 
Once() + adapter.On("Now").Return(time.Now()).Once() err = hook.OnSchedEvent(ctx, &mockTimerSchedEvent{eventID: timer.EventID, timer: timer}) require.NoError(t, err) adapter.AssertExpectations(t) @@ -366,6 +379,7 @@ func TestTTLTimerHookOnEvent(t *testing.T) { Return(nil, nil). Once() require.False(t, timer.Enable) + adapter.On("Now").Return(time.Now()).Once() err = hook.OnSchedEvent(ctx, &mockTimerSchedEvent{eventID: timer.EventID, timer: timer}) require.NoError(t, err) adapter.AssertExpectations(t) @@ -382,9 +396,7 @@ func TestTTLTimerHookOnEvent(t *testing.T) { watermark = time.Unix(3600*789, 0) require.NoError(t, cli.UpdateTimer(ctx, timer.ID, timerapi.WithSetWatermark(watermark))) timer = triggerTestTimer(t, store, timer.ID) - hook.nowFunc = func() time.Time { - return timer.EventStart.Add(11 * time.Minute) - } + adapter.On("Now").Return(timer.EventStart.Add(11*time.Minute), nil).Once() adapter.On("GetJob", ctx, data.TableID, data.PhysicalID, timer.EventID). Return(nil, nil). Once()