Files
tidb/pkg/expression/generator/builtin_threadsafe.go

236 lines
6.6 KiB
Go

// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"bytes"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"log"
"os"
"path"
"sort"
"strings"
)
// specialSafeFuncs names builtin*Sig types that collectThreadSafeBuiltinFuncs
// always classifies as safe to share across sessions, skipping the structural
// "single embedded base field" check it applies to every other type.
// NOTE: please make sure there are test cases for all functions here.
var specialSafeFuncs = map[string]struct{}{
	"builtinDecimalIsFalseSig": {},
	"builtinDecimalIsTrueSig":  {},
	"builtinInDecimalSig":      {},
	"builtinInDurationSig":     {},
	"builtinInIntSig":          {},
	"builtinInRealSig":         {},
	"builtinInStringSig":       {},
	"builtinInTimeSig":         {},
	"builtinIntIsFalseSig":     {},
	"builtinIntIsTrueSig":      {},
	"builtinRealIsFalseSig":    {},
	"builtinRealIsTrueSig":     {},
}
// collectThreadSafeBuiltinFuncs parses the Go source file at path file and
// classifies every top-level struct type named "builtin*Sig" it declares.
//
// A type goes to safeFuncNames when its name is in specialSafeFuncs, or when
// its struct has exactly one field and that field is an embedded
// baseBuiltinFunc or baseBuiltinCastFunc. Every other top-level builtin*Sig
// struct goes to unsafeFuncNames. Both slices keep the file's declaration
// order. The function panics if the file cannot be parsed.
//
// Only top-level declarations are inspected. A type declared inside a function
// body cannot have methods, so generating SafeToShareAcrossSession for it
// would produce code that does not compile.
func collectThreadSafeBuiltinFuncs(file string) (safeFuncNames, unsafeFuncNames []string) {
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, file, nil, 0)
	if err != nil {
		panic(err)
	}
	for _, decl := range f.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.TYPE {
			continue // only `type ...` declarations matter
		}
		for _, spec := range genDecl.Specs {
			typeSpec, ok := spec.(*ast.TypeSpec)
			if !ok {
				continue
			}
			typeName := typeSpec.Name.Name
			if !strings.HasPrefix(typeName, "builtin") || !strings.HasSuffix(typeName, "Sig") {
				continue // the type name should be "builtin*Sig"
			}
			structType, ok := typeSpec.Type.(*ast.StructType)
			if !ok { // the type must be a structure
				continue
			}
			if isThreadSafeBuiltinStruct(typeName, structType) {
				safeFuncNames = append(safeFuncNames, typeName)
			} else {
				unsafeFuncNames = append(unsafeFuncNames, typeName)
			}
		}
	}
	return safeFuncNames, unsafeFuncNames
}

// isThreadSafeBuiltinStruct reports whether the builtin*Sig struct typeName,
// defined as structType, should receive the thread-safe method template.
func isThreadSafeBuiltinStruct(typeName string, structType *ast.StructType) bool {
	if _, ok := specialSafeFuncs[typeName]; ok {
		return true
	}
	fields := structType.Fields.List
	// The struct must consist of a single *embedded* field (no names). A named
	// field does not promote args/safeToShareAcrossSessionFlag, which the safe
	// template dereferences through the receiver.
	if len(fields) != 1 || len(fields[0].Names) != 0 {
		return false
	}
	ident, ok := fields[0].Type.(*ast.Ident)
	return ok && (ident.Name == "baseBuiltinFunc" || ident.Name == "baseBuiltinCastFunc")
}
// genBuiltinThreadSafeCode scans exprCodeDir for "builtin_*.go" files whose
// names do not contain "_test", classifies the builtin*Sig types they declare,
// and returns the gofmt-ed contents of the thread-safe and thread-unsafe
// generated files. Safe names are emitted in sorted order. Unsafe names are
// emitted in file-then-declaration order. Any read, parse or format failure
// panics.
//
// Note: the generated outputs themselves match the "builtin_*.go" filter and
// are re-parsed on later runs. They declare no types, so they contribute
// nothing.
func genBuiltinThreadSafeCode(exprCodeDir string) (safe, unsafe []byte) {
	dirEntries, readErr := os.ReadDir(exprCodeDir)
	if readErr != nil {
		panic(readErr)
	}

	var sources []string
	for _, de := range dirEntries {
		name := de.Name()
		isBuiltinSource := !de.IsDir() &&
			strings.HasPrefix(name, "builtin_") &&
			strings.HasSuffix(name, ".go") &&
			!strings.Contains(name, "_test")
		if isBuiltinSource {
			sources = append(sources, name)
		}
	}
	// Visit files in a fixed order so the generated output is reproducible.
	sort.Strings(sources)

	var safeNames, unsafeNames []string
	for _, src := range sources {
		s, u := collectThreadSafeBuiltinFuncs(path.Join(exprCodeDir, src))
		safeNames = append(safeNames, s...)
		unsafeNames = append(unsafeNames, u...)
	}
	sort.Strings(safeNames)

	var err error
	if safe, err = generateCode(safeNames, safeHeader, safeFuncTemp); err != nil {
		panic(err)
	}
	if unsafe, err = generateCode(unsafeNames, unsafeHeader, unsafeFuncTemp); err != nil {
		panic(err)
	}
	return safe, unsafe
}
// generateCode writes header followed by one rendering of template per entry
// of funcNames, in order. Each entry is substituted for the template's single
// %s verb. The assembled text is run through go/format. It returns an error
// (with context) when that text is not valid Go source.
func generateCode(funcNames []string, header, template string) ([]byte, error) {
	var buffer bytes.Buffer
	buffer.WriteString(header)
	for _, funcName := range funcNames {
		// Format directly into the buffer; no intermediate string needed.
		fmt.Fprintf(&buffer, template, funcName)
	}
	formatted, err := format.Source(buffer.Bytes())
	if err != nil {
		// The position in err refers to the in-memory buffer, so say what failed.
		return nil, fmt.Errorf("formatting generated code for %d functions: %w", len(funcNames), err)
	}
	return formatted, nil
}
// main regenerates builtin_threadsafe_generated.go and
// builtin_threadunsafe_generated.go in the working directory from the
// builtin_*.go sources found there. It exits via log.Fatalln if a file cannot
// be written.
func main() {
	safeCode, unsafeCode := genBuiltinThreadSafeCode(".")
	outputs := []struct {
		file   string
		errMsg string
		code   []byte
	}{
		{"./builtin_threadsafe_generated.go", "failed to write builtin_threadsafe_generated.go", safeCode},
		{"./builtin_threadunsafe_generated.go", "failed to write builtin_threadunsafe_generated.go", unsafeCode},
	}
	for _, out := range outputs {
		if err := os.WriteFile(out.file, out.code, 0644); err != nil {
			log.Fatalln(out.errMsg, err)
		}
	}
}
// Text fragments consumed by generateCode. Each *FuncTemp is a fmt format
// string whose single %s verb receives a builtin*Sig type name. Each *Header is
// written once at the top of the corresponding generated file.
const (
	// safeFuncTemp emits a SafeToShareAcrossSession method that delegates to
	// the safeToShareAcrossSession helper (defined in safeHeader). It passes the
	// receiver's cached flag and its argument expressions.
	safeFuncTemp = `// SafeToShareAcrossSession implements BuiltinFunc.SafeToShareAcrossSession.
func (s *%s) SafeToShareAcrossSession() bool {
return safeToShareAcrossSession(&s.safeToShareAcrossSessionFlag, s.args)
}
`
	// unsafeFuncTemp emits a SafeToShareAcrossSession method that always
	// returns false.
	unsafeFuncTemp = `// SafeToShareAcrossSession implements BuiltinFunc.SafeToShareAcrossSession.
func (s *%s) SafeToShareAcrossSession() bool {
return false
}
`
	// safeHeader starts the thread-safe generated file. After the license and
	// the "Code generated" marker, it defines safeToShareAcrossSession. That
	// function caches its answer in *flag via sync/atomic (0 = not computed
	// yet, 1 = safe, 2 = unsafe). It reports safe only when every argument's
	// SafeToShareAcrossSession() returns true.
	safeHeader = `// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Code generated by go generate in expression/generator; DO NOT EDIT.
package expression
import "sync/atomic"
func safeToShareAcrossSession(flag *uint32, args []Expression) bool {
flagV := atomic.LoadUint32(flag)
if flagV != 0 {
return flagV == 1
}
allArgsSafe := true
for _, arg := range args {
if !arg.SafeToShareAcrossSession() {
allArgsSafe = false
break
}
}
if allArgsSafe {
atomic.StoreUint32(flag, 1)
} else {
atomic.StoreUint32(flag, 2)
}
return allArgsSafe
}
`
	// unsafeHeader starts the thread-unsafe generated file. It contains only
	// the license, the "Code generated" marker and the package clause.
	unsafeHeader = `// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Code generated by go generate in expression/generator; DO NOT EDIT.
package expression
`
)