Files
tidb/timer/api/mem_store.go

225 lines
4.6 KiB
Go

// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"context"
"encoding/hex"
"sync"
"time"
"github.com/google/uuid"
"github.com/pingcap/errors"
"golang.org/x/exp/slices"
)
type memStoreWatcher struct {
ctx context.Context
ch chan WatchTimerResponse
}
type memoryStoreCore struct {
mu sync.Mutex
namespaces map[string]map[string]*TimerRecord
id2Timers map[string]*TimerRecord
watchers []*memStoreWatcher
}
// NewMemoryTimerStore creates a memory store for timers
func NewMemoryTimerStore() *TimerStore {
return &TimerStore{
TimerStoreCore: &memoryStoreCore{
namespaces: make(map[string]map[string]*TimerRecord),
id2Timers: make(map[string]*TimerRecord),
},
}
}
func (s *memoryStoreCore) Create(_ context.Context, record *TimerRecord) (string, error) {
if record == nil {
return "", errors.New("timer should not be nil")
}
if err := record.Validate(); err != nil {
return "", err
}
s.mu.Lock()
defer s.mu.Unlock()
record = record.Clone()
if record.ID == "" {
uid := uuid.New()
record.ID = hex.EncodeToString(uid[:])
}
if record.Version == 0 {
record.Version = 1
}
if record.EventStatus == "" {
record.EventStatus = SchedEventIdle
}
if record.CreateTime.IsZero() {
record.CreateTime = time.Now()
}
if _, ok := s.id2Timers[record.ID]; ok {
return "", errors.Trace(ErrTimerExists)
}
ns, ok := s.namespaces[record.Namespace]
if !ok {
ns = make(map[string]*TimerRecord)
s.namespaces[record.Namespace] = ns
} else {
if _, ok = ns[record.Key]; ok {
return "", errors.Trace(ErrTimerExists)
}
}
s.id2Timers[record.ID] = record
ns[record.Key] = record
s.notify(WatchTimerEventCreate, record.ID)
return record.ID, nil
}
func (s *memoryStoreCore) List(_ context.Context, cond Cond) ([]*TimerRecord, error) {
result := make([]*TimerRecord, 0, 1)
for _, ns := range s.namespaces {
for _, t := range ns {
if cond == nil || cond.Match(t) {
result = append(result, t.Clone())
}
}
}
return result, nil
}
func (s *memoryStoreCore) Update(_ context.Context, timerID string, update *TimerUpdate) error {
if update == nil {
return errors.New("update should not be nil")
}
s.mu.Lock()
defer s.mu.Unlock()
record, ok := s.id2Timers[timerID]
if !ok {
return ErrTimerNotExist
}
newRecord, err := update.Apply(record)
if err != nil {
return err
}
if err = newRecord.Validate(); err != nil {
return err
}
newRecord.Version++
s.id2Timers[timerID] = newRecord
s.namespaces[record.Namespace][record.Key] = newRecord
s.notify(WatchTimerEventUpdate, timerID)
return nil
}
func (s *memoryStoreCore) Delete(_ context.Context, timerID string) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
record, ok := s.id2Timers[timerID]
if !ok {
return false, nil
}
delete(s.id2Timers, timerID)
ns := s.namespaces[record.Namespace]
delete(ns, record.Key)
if len(ns) == 0 {
delete(s.namespaces, record.Namespace)
}
s.notify(WatchTimerEventDelete, timerID)
return true, nil
}
func (*memoryStoreCore) WatchSupported() bool {
return true
}
func (s *memoryStoreCore) Watch(ctx context.Context) WatchTimerChan {
s.mu.Lock()
defer s.mu.Unlock()
watcher := &memStoreWatcher{
ctx: ctx,
ch: make(chan WatchTimerResponse, 8),
}
s.watchers = append(s.watchers, watcher)
ch := make(chan WatchTimerResponse)
go func() {
defer func() {
s.mu.Lock()
defer s.mu.Unlock()
close(ch)
if i := slices.Index(s.watchers, watcher); i >= 0 {
s.watchers = slices.Delete(s.watchers, i, i+1)
}
}()
for {
select {
case <-ctx.Done():
return
case resp := <-watcher.ch:
select {
case <-ctx.Done():
return
case ch <- resp:
}
}
}
}()
return ch
}
func (s *memoryStoreCore) notify(tp WatchTimerEventType, timerID string) {
resp := WatchTimerResponse{
Events: []*WatchTimerEvent{
{
Tp: tp,
TimerID: timerID,
},
},
}
for _, w := range s.watchers {
select {
case <-w.ctx.Done():
return
case w.ch <- resp:
default:
watcher := w
go func() {
select {
case <-watcher.ctx.Done():
return
case watcher.ch <- resp:
}
}()
}
}
}