// Source: tidb/pkg/timer/api/mem_store.go
//
// 327 lines
// 6.8 KiB
// Go

// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"context"
"encoding/hex"
"slices"
"sync"
"time"
"github.com/google/uuid"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/util/timeutil"
)
// memStoreWatcher is the per-subscriber state kept by memTimerWatchEventNotifier
// for each active Watch call.
type memStoreWatcher struct {
// ctx is the context the subscriber passed to Watch; once it is done the
// notifier stops delivering events to this watcher.
ctx context.Context
// ch is the internal queue between Notify (sender) and the watcher's
// forwarding goroutine (receiver).
ch chan WatchTimerResponse
}
// memoryStoreCore is a timer store core that keeps all records in process memory.
type memoryStoreCore struct {
// mu guards namespaces and id2Timers.
mu sync.RWMutex
// namespaces indexes records by namespace and then by key.
namespaces map[string]map[string]*TimerRecord
// id2Timers indexes the same records by timer ID.
id2Timers map[string]*TimerRecord
// notifier is told about every successful create, update and delete.
notifier TimerWatchEventNotifier
}
// NewMemoryTimerStore returns a TimerStore backed by an in-memory core.
// The core starts with empty indexes and its own in-memory watch notifier.
func NewMemoryTimerStore() *TimerStore {
	core := &memoryStoreCore{
		namespaces: make(map[string]map[string]*TimerRecord),
		id2Timers:  make(map[string]*TimerRecord),
		notifier:   NewMemTimerWatchEventNotifier(),
	}
	return &TimerStore{TimerStoreCore: core}
}
// Create stores a copy of record as a new timer and returns its generated ID.
//
// The input must be non-nil, must not carry an ID, Version or CreateTime, and
// must pass Validate. The stored copy gets a hex-encoded, freshly generated
// UUID as ID, Version 1, the current time as CreateTime, a Location derived
// from its TimeZone, and SchedEventIdle as EventStatus when none was given.
//
// It returns ErrTimerExists when the generated ID or the (Namespace, Key) pair
// is already present. On success a WatchTimerEventCreate event is published.
func (s *memoryStoreCore) Create(_ context.Context, record *TimerRecord) (string, error) {
	switch {
	case record == nil:
		return "", errors.New("timer should not be nil")
	case record.ID != "":
		return "", errors.New("ID should not be specified when create record")
	case record.Version != 0:
		return "", errors.New("Version should not be specified when create record")
	case !record.CreateTime.IsZero():
		return "", errors.New("CreateTime should not be specified when create record")
	}

	if err := record.Validate(); err != nil {
		return "", err
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Work on a private copy so the caller's record is never modified or shared.
	created := record.Clone()
	rawID := uuid.New()
	created.ID = hex.EncodeToString(rawID[:])
	created.Location = getMemStoreTimeZoneLoc(created.TimeZone)
	created.Version = 1
	created.CreateTime = time.Now()
	if created.EventStatus == "" {
		created.EventStatus = SchedEventIdle
	}
	normalizeTimeFields(created)

	if _, dup := s.id2Timers[created.ID]; dup {
		return "", errors.Trace(ErrTimerExists)
	}

	bucket, found := s.namespaces[created.Namespace]
	if found {
		// Keys are unique within a namespace.
		if _, taken := bucket[created.Key]; taken {
			return "", errors.Trace(ErrTimerExists)
		}
	} else {
		bucket = make(map[string]*TimerRecord)
		s.namespaces[created.Namespace] = bucket
	}

	s.id2Timers[created.ID] = created
	bucket[created.Key] = created
	s.notifier.Notify(WatchTimerEventCreate, created.ID)
	return created.ID, nil
}
// List returns clones of all stored timers that satisfy cond.
// A nil cond matches every timer. The result order follows map iteration
// and is therefore unspecified. The returned error is always nil.
func (s *memoryStoreCore) List(_ context.Context, cond Cond) ([]*TimerRecord, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	matched := make([]*TimerRecord, 0, 1)
	for _, timers := range s.namespaces {
		for _, timer := range timers {
			if cond != nil && !cond.Match(timer) {
				continue
			}
			// Hand out clones so callers cannot mutate stored records.
			matched = append(matched, timer.Clone())
		}
	}
	return matched, nil
}
// Update applies update to the timer identified by timerID.
//
// It returns an error when update is nil, ErrTimerNotExist when the ID is
// unknown, and otherwise any error from update.apply or from validating the
// resulting record. On success the resulting record, with its Version
// incremented, replaces the stored one in both indexes and a
// WatchTimerEventUpdate event is published.
func (s *memoryStoreCore) Update(_ context.Context, timerID string, update *TimerUpdate) error {
	if update == nil {
		return errors.New("update should not be nil")
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	current, exists := s.id2Timers[timerID]
	if !exists {
		return ErrTimerNotExist
	}

	updated, err := update.apply(current)
	if err != nil {
		return err
	}

	normalizeTimeFields(updated)
	if err = updated.Validate(); err != nil {
		return err
	}

	updated.Version++
	s.id2Timers[timerID] = updated
	// The namespace index is addressed through the previous record's
	// Namespace and Key.
	s.namespaces[current.Namespace][current.Key] = updated
	s.notifier.Notify(WatchTimerEventUpdate, timerID)
	return nil
}
// Delete removes the timer identified by timerID from both indexes.
// It reports whether a timer was actually removed; an unknown ID yields
// (false, nil). A namespace left without timers is dropped as well. On
// removal a WatchTimerEventDelete event is published.
func (s *memoryStoreCore) Delete(_ context.Context, timerID string) (bool, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	victim, found := s.id2Timers[timerID]
	if !found {
		return false, nil
	}

	delete(s.id2Timers, timerID)

	bucket := s.namespaces[victim.Namespace]
	delete(bucket, victim.Key)
	if len(bucket) == 0 {
		// Do not keep empty namespace maps around.
		delete(s.namespaces, victim.Namespace)
	}

	s.notifier.Notify(WatchTimerEventDelete, timerID)
	return true, nil
}
// WatchSupported reports whether this store supports Watch; for the memory
// store it always returns true.
func (*memoryStoreCore) WatchSupported() bool {
return true
}
// Watch subscribes to timer change events by delegating to the store's
// notifier and returns the channel the notifier hands back.
func (s *memoryStoreCore) Watch(ctx context.Context) WatchTimerChan {
return s.notifier.Watch(ctx)
}
// Close closes the store's notifier. The stored timers themselves are left
// untouched.
func (s *memoryStoreCore) Close() {
s.notifier.Close()
}
// memTimerWatchEventNotifier fans timer change events out to in-process
// watchers.
type memTimerWatchEventNotifier struct {
// mu guards cancel and watchers.
mu sync.RWMutex
// ctx is canceled by Close and stops all goroutines started by the notifier.
ctx context.Context
// wg tracks the goroutines started by Watch and Notify so Close can wait
// for them.
wg sync.WaitGroup
// cancel cancels ctx; it is set to nil once Close has run, which is how
// Watch and Notify detect a closed notifier.
cancel func()
// watchers holds one entry per active Watch call.
watchers []*memStoreWatcher
}
// NewMemTimerWatchEventNotifier returns a TimerWatchEventNotifier that
// delivers events to watchers living in the same process. The notifier owns a
// cancelable background context that is released by Close.
func NewMemTimerWatchEventNotifier() TimerWatchEventNotifier {
	notifier := &memTimerWatchEventNotifier{
		watchers: make([]*memStoreWatcher, 0, 8),
	}
	notifier.ctx, notifier.cancel = context.WithCancel(context.Background())
	return notifier
}
// Watch registers a new watcher and returns the channel on which it receives
// events. The channel is closed when ctx is done or the notifier is closed.
// Calling Watch on an already closed notifier returns a closed channel.
func (n *memTimerWatchEventNotifier) Watch(ctx context.Context) WatchTimerChan {
	n.mu.Lock()
	defer n.mu.Unlock()

	out := make(chan WatchTimerResponse)
	if n.cancel == nil {
		// Close has already run; there will never be anything to deliver.
		close(out)
		return out
	}

	sub := &memStoreWatcher{
		ctx: ctx,
		ch:  make(chan WatchTimerResponse, 8),
	}
	n.watchers = append(n.watchers, sub)

	n.wg.Add(1)
	go func() {
		// On exit: close the subscriber channel, unregister the watcher and
		// release the WaitGroup slot, all under the write lock.
		defer func() {
			n.mu.Lock()
			defer n.mu.Unlock()
			close(out)
			if idx := slices.Index(n.watchers, sub); idx >= 0 {
				n.watchers = slices.Delete(n.watchers, idx, idx+1)
			}
			n.wg.Done()
		}()

		for {
			// Wait for the next queued response...
			var resp WatchTimerResponse
			select {
			case <-n.ctx.Done():
				return
			case <-ctx.Done():
				return
			case resp = <-sub.ch:
			}

			// ...and forward it to the subscriber, giving up on cancellation.
			select {
			case <-n.ctx.Done():
				return
			case <-ctx.Done():
				return
			case out <- resp:
			}
		}
	}()
	return out
}
// Notify publishes a single event of type tp for timerID to every registered
// watcher. It is a no-op once the notifier has been closed.
//
// Notify never blocks on a slow watcher: when a watcher's queue cannot take
// the response immediately, delivery is handed to a goroutine tracked by the
// notifier's WaitGroup, which gives up when the notifier is closed or the
// watcher's context is done. Responses delivered this way may therefore
// arrive after responses that were published later.
func (n *memTimerWatchEventNotifier) Notify(tp WatchTimerEventType, timerID string) {
	n.mu.RLock()
	defer n.mu.RUnlock()
	if n.cancel == nil {
		return
	}

	event := &WatchTimerEvent{Tp: tp, TimerID: timerID}
	resp := WatchTimerResponse{Events: []*WatchTimerEvent{event}}

	for i := range n.watchers {
		target := n.watchers[i]

		select {
		case <-n.ctx.Done():
			return
		case <-target.ctx.Done():
			continue
		case target.ch <- resp:
			continue
		default:
		}

		// The queue is full: finish the delivery asynchronously.
		n.wg.Add(1)
		go func() {
			defer n.wg.Done()
			select {
			case <-n.ctx.Done():
				return
			case <-target.ctx.Done():
				return
			case target.ch <- resp:
			}
		}()
	}
}
// Close cancels the notifier's context and waits for every goroutine started
// by Watch and Notify to finish. Calling it more than once is safe; later
// calls only wait.
func (n *memTimerWatchEventNotifier) Close() {
	func() {
		n.mu.Lock()
		defer n.mu.Unlock()
		if n.cancel == nil {
			return
		}
		n.cancel()
		n.cancel = nil
	}()

	// The lock must be released before waiting: exiting watcher goroutines
	// take it during their cleanup.
	n.wg.Wait()
}
// getMemStoreTimeZoneLoc resolves tz to a *time.Location. An empty tz, or one
// that timeutil.ParseTimeZone rejects, resolves to the system location.
func getMemStoreTimeZoneLoc(tz string) *time.Location {
	if tz != "" {
		if loc, err := timeutil.ParseTimeZone(tz); err == nil {
			return loc
		}
	}
	return timeutil.SystemLocation()
}
// normalizeTimeFields re-expresses the record's Watermark, EventStart and
// CreateTime in record.Location. Zero times are left untouched, and nothing
// happens when the record has no Location.
func normalizeTimeFields(record *TimerRecord) {
	loc := record.Location
	if loc == nil {
		return
	}

	fields := []*time.Time{
		&record.Watermark,
		&record.EventStart,
		&record.CreateTime,
	}
	for _, ts := range fields {
		if ts.IsZero() {
			continue
		}
		*ts = ts.In(loc)
	}
}