1171 lines
33 KiB
Go
1171 lines
33 KiB
Go
// Copyright 2015 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package ast
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/pingcap/errors"
|
|
"github.com/pingcap/tidb/pkg/parser/format"
|
|
"github.com/pingcap/tidb/pkg/parser/model"
|
|
"github.com/pingcap/tidb/pkg/parser/types"
|
|
)
|
|
|
|
var (
|
|
_ FuncNode = &AggregateFuncExpr{}
|
|
_ FuncNode = &FuncCallExpr{}
|
|
_ FuncNode = &FuncCastExpr{}
|
|
_ FuncNode = &WindowFuncExpr{}
|
|
)
|
|
|
|
// List scalar function names.
|
|
const (
|
|
LogicAnd = "and"
|
|
Cast = "cast"
|
|
LeftShift = "leftshift"
|
|
RightShift = "rightshift"
|
|
LogicOr = "or"
|
|
GE = "ge"
|
|
LE = "le"
|
|
EQ = "eq"
|
|
NE = "ne"
|
|
LT = "lt"
|
|
GT = "gt"
|
|
Plus = "plus"
|
|
Minus = "minus"
|
|
And = "bitand"
|
|
Or = "bitor"
|
|
Mod = "mod"
|
|
Xor = "bitxor"
|
|
Div = "div"
|
|
Mul = "mul"
|
|
UnaryNot = "not" // Avoid name conflict with Not in github/pingcap/check.
|
|
BitNeg = "bitneg"
|
|
IntDiv = "intdiv"
|
|
LogicXor = "xor"
|
|
NullEQ = "nulleq"
|
|
UnaryPlus = "unaryplus"
|
|
UnaryMinus = "unaryminus"
|
|
In = "in"
|
|
Like = "like"
|
|
Ilike = "ilike"
|
|
Case = "case"
|
|
Regexp = "regexp"
|
|
RegexpLike = "regexp_like"
|
|
RegexpSubstr = "regexp_substr"
|
|
RegexpInStr = "regexp_instr"
|
|
RegexpReplace = "regexp_replace"
|
|
IsNull = "isnull"
|
|
IsTruthWithoutNull = "istrue" // Avoid name conflict with IsTrue in github/pingcap/check.
|
|
IsTruthWithNull = "istrue_with_null"
|
|
IsFalsity = "isfalse" // Avoid name conflict with IsFalse in github/pingcap/check.
|
|
RowFunc = "row"
|
|
SetVar = "setvar"
|
|
GetVar = "getvar"
|
|
Values = "values"
|
|
BitCount = "bit_count"
|
|
GetParam = "getparam"
|
|
|
|
// common functions
|
|
Coalesce = "coalesce"
|
|
Greatest = "greatest"
|
|
Least = "least"
|
|
Interval = "interval"
|
|
|
|
// math functions
|
|
Abs = "abs"
|
|
Acos = "acos"
|
|
Asin = "asin"
|
|
Atan = "atan"
|
|
Atan2 = "atan2"
|
|
Ceil = "ceil"
|
|
Ceiling = "ceiling"
|
|
Conv = "conv"
|
|
Cos = "cos"
|
|
Cot = "cot"
|
|
CRC32 = "crc32"
|
|
Degrees = "degrees"
|
|
Exp = "exp"
|
|
Floor = "floor"
|
|
Ln = "ln"
|
|
Log = "log"
|
|
Log2 = "log2"
|
|
Log10 = "log10"
|
|
PI = "pi"
|
|
Pow = "pow"
|
|
Power = "power"
|
|
Radians = "radians"
|
|
Rand = "rand"
|
|
Round = "round"
|
|
Sign = "sign"
|
|
Sin = "sin"
|
|
Sqrt = "sqrt"
|
|
Tan = "tan"
|
|
Truncate = "truncate"
|
|
|
|
// time functions
|
|
AddDate = "adddate"
|
|
AddTime = "addtime"
|
|
ConvertTz = "convert_tz"
|
|
Curdate = "curdate"
|
|
CurrentDate = "current_date"
|
|
CurrentTime = "current_time"
|
|
CurrentTimestamp = "current_timestamp"
|
|
Curtime = "curtime"
|
|
Date = "date"
|
|
DateLiteral = "'tidb`.(dateliteral"
|
|
DateAdd = "date_add"
|
|
DateFormat = "date_format"
|
|
DateSub = "date_sub"
|
|
DateDiff = "datediff"
|
|
Day = "day"
|
|
DayName = "dayname"
|
|
DayOfMonth = "dayofmonth"
|
|
DayOfWeek = "dayofweek"
|
|
DayOfYear = "dayofyear"
|
|
Extract = "extract"
|
|
FromDays = "from_days"
|
|
FromUnixTime = "from_unixtime"
|
|
GetFormat = "get_format"
|
|
Hour = "hour"
|
|
LocalTime = "localtime"
|
|
LocalTimestamp = "localtimestamp"
|
|
MakeDate = "makedate"
|
|
MakeTime = "maketime"
|
|
MicroSecond = "microsecond"
|
|
Minute = "minute"
|
|
Month = "month"
|
|
MonthName = "monthname"
|
|
Now = "now"
|
|
PeriodAdd = "period_add"
|
|
PeriodDiff = "period_diff"
|
|
Quarter = "quarter"
|
|
SecToTime = "sec_to_time"
|
|
Second = "second"
|
|
StrToDate = "str_to_date"
|
|
SubDate = "subdate"
|
|
SubTime = "subtime"
|
|
Sysdate = "sysdate"
|
|
Time = "time"
|
|
TimeLiteral = "'tidb`.(timeliteral"
|
|
TimeFormat = "time_format"
|
|
TimeToSec = "time_to_sec"
|
|
TimeDiff = "timediff"
|
|
Timestamp = "timestamp"
|
|
TimestampLiteral = "'tidb`.(timestampliteral"
|
|
TimestampAdd = "timestampadd"
|
|
TimestampDiff = "timestampdiff"
|
|
ToDays = "to_days"
|
|
ToSeconds = "to_seconds"
|
|
UnixTimestamp = "unix_timestamp"
|
|
UTCDate = "utc_date"
|
|
UTCTime = "utc_time"
|
|
UTCTimestamp = "utc_timestamp"
|
|
Week = "week"
|
|
Weekday = "weekday"
|
|
WeekOfYear = "weekofyear"
|
|
Year = "year"
|
|
YearWeek = "yearweek"
|
|
LastDay = "last_day"
|
|
// TSO functions
|
|
// TiDBBoundedStaleness is used to determine the TS for a read only request with the given bounded staleness.
|
|
// It will be used in the Stale Read feature.
|
|
// For more info, please see AsOfClause.
|
|
TiDBBoundedStaleness = "tidb_bounded_staleness"
|
|
TiDBParseTso = "tidb_parse_tso"
|
|
TiDBParseTsoLogical = "tidb_parse_tso_logical"
|
|
TiDBCurrentTso = "tidb_current_tso"
|
|
|
|
// string functions
|
|
ASCII = "ascii"
|
|
Bin = "bin"
|
|
Concat = "concat"
|
|
ConcatWS = "concat_ws"
|
|
Convert = "convert"
|
|
Elt = "elt"
|
|
ExportSet = "export_set"
|
|
Field = "field"
|
|
Format = "format"
|
|
FromBase64 = "from_base64"
|
|
InsertFunc = "insert_func"
|
|
Instr = "instr"
|
|
Lcase = "lcase"
|
|
Left = "left"
|
|
Length = "length"
|
|
LoadFile = "load_file"
|
|
Locate = "locate"
|
|
Lower = "lower"
|
|
Lpad = "lpad"
|
|
LTrim = "ltrim"
|
|
MakeSet = "make_set"
|
|
Mid = "mid"
|
|
Oct = "oct"
|
|
OctetLength = "octet_length"
|
|
Ord = "ord"
|
|
Position = "position"
|
|
Quote = "quote"
|
|
Repeat = "repeat"
|
|
Replace = "replace"
|
|
Reverse = "reverse"
|
|
Right = "right"
|
|
RTrim = "rtrim"
|
|
Space = "space"
|
|
Strcmp = "strcmp"
|
|
Substring = "substring"
|
|
Substr = "substr"
|
|
SubstringIndex = "substring_index"
|
|
ToBase64 = "to_base64"
|
|
Trim = "trim"
|
|
Translate = "translate"
|
|
Upper = "upper"
|
|
Ucase = "ucase"
|
|
Hex = "hex"
|
|
Unhex = "unhex"
|
|
Rpad = "rpad"
|
|
BitLength = "bit_length"
|
|
CharFunc = "char_func"
|
|
CharLength = "char_length"
|
|
CharacterLength = "character_length"
|
|
FindInSet = "find_in_set"
|
|
WeightString = "weight_string"
|
|
Soundex = "soundex"
|
|
|
|
// information functions
|
|
Benchmark = "benchmark"
|
|
Charset = "charset"
|
|
Coercibility = "coercibility"
|
|
Collation = "collation"
|
|
ConnectionID = "connection_id"
|
|
CurrentUser = "current_user"
|
|
CurrentRole = "current_role"
|
|
Database = "database"
|
|
FoundRows = "found_rows"
|
|
LastInsertId = "last_insert_id"
|
|
RowCount = "row_count"
|
|
Schema = "schema"
|
|
SessionUser = "session_user"
|
|
SystemUser = "system_user"
|
|
User = "user"
|
|
Version = "version"
|
|
TiDBVersion = "tidb_version"
|
|
TiDBIsDDLOwner = "tidb_is_ddl_owner"
|
|
TiDBDecodePlan = "tidb_decode_plan"
|
|
TiDBDecodeBinaryPlan = "tidb_decode_binary_plan"
|
|
TiDBDecodeSQLDigests = "tidb_decode_sql_digests"
|
|
TiDBEncodeSQLDigest = "tidb_encode_sql_digest"
|
|
FormatBytes = "format_bytes"
|
|
FormatNanoTime = "format_nano_time"
|
|
CurrentResourceGroup = "current_resource_group"
|
|
|
|
// control functions
|
|
If = "if"
|
|
Ifnull = "ifnull"
|
|
Nullif = "nullif"
|
|
|
|
// miscellaneous functions
|
|
AnyValue = "any_value"
|
|
DefaultFunc = "default_func"
|
|
InetAton = "inet_aton"
|
|
InetNtoa = "inet_ntoa"
|
|
Inet6Aton = "inet6_aton"
|
|
Inet6Ntoa = "inet6_ntoa"
|
|
IsFreeLock = "is_free_lock"
|
|
IsIPv4 = "is_ipv4"
|
|
IsIPv4Compat = "is_ipv4_compat"
|
|
IsIPv4Mapped = "is_ipv4_mapped"
|
|
IsIPv6 = "is_ipv6"
|
|
IsUsedLock = "is_used_lock"
|
|
IsUUID = "is_uuid"
|
|
MasterPosWait = "master_pos_wait"
|
|
NameConst = "name_const"
|
|
ReleaseAllLocks = "release_all_locks"
|
|
Sleep = "sleep"
|
|
UUID = "uuid"
|
|
UUIDShort = "uuid_short"
|
|
UUIDToBin = "uuid_to_bin"
|
|
BinToUUID = "bin_to_uuid"
|
|
VitessHash = "vitess_hash"
|
|
TiDBShard = "tidb_shard"
|
|
TiDBRowChecksum = "tidb_row_checksum"
|
|
GetLock = "get_lock"
|
|
ReleaseLock = "release_lock"
|
|
Grouping = "grouping"
|
|
|
|
// encryption and compression functions
|
|
AesDecrypt = "aes_decrypt"
|
|
AesEncrypt = "aes_encrypt"
|
|
Compress = "compress"
|
|
Decode = "decode"
|
|
DesDecrypt = "des_decrypt"
|
|
DesEncrypt = "des_encrypt"
|
|
Encode = "encode"
|
|
Encrypt = "encrypt"
|
|
MD5 = "md5"
|
|
OldPassword = "old_password"
|
|
PasswordFunc = "password_func"
|
|
RandomBytes = "random_bytes"
|
|
SHA1 = "sha1"
|
|
SHA = "sha"
|
|
SHA2 = "sha2"
|
|
SM3 = "sm3"
|
|
Uncompress = "uncompress"
|
|
UncompressedLength = "uncompressed_length"
|
|
ValidatePasswordStrength = "validate_password_strength"
|
|
|
|
// json functions
|
|
JSONType = "json_type"
|
|
JSONExtract = "json_extract"
|
|
JSONUnquote = "json_unquote"
|
|
JSONArray = "json_array"
|
|
JSONObject = "json_object"
|
|
JSONMerge = "json_merge"
|
|
JSONSet = "json_set"
|
|
JSONInsert = "json_insert"
|
|
JSONReplace = "json_replace"
|
|
JSONRemove = "json_remove"
|
|
JSONOverlaps = "json_overlaps"
|
|
JSONContains = "json_contains"
|
|
JSONMemberOf = "json_memberof"
|
|
JSONContainsPath = "json_contains_path"
|
|
JSONValid = "json_valid"
|
|
JSONArrayAppend = "json_array_append"
|
|
JSONArrayInsert = "json_array_insert"
|
|
JSONMergePatch = "json_merge_patch"
|
|
JSONMergePreserve = "json_merge_preserve"
|
|
JSONPretty = "json_pretty"
|
|
JSONQuote = "json_quote"
|
|
JSONSearch = "json_search"
|
|
JSONStorageFree = "json_storage_free"
|
|
JSONStorageSize = "json_storage_size"
|
|
JSONDepth = "json_depth"
|
|
JSONKeys = "json_keys"
|
|
JSONLength = "json_length"
|
|
|
|
// TiDB internal function.
|
|
TiDBDecodeKey = "tidb_decode_key"
|
|
TiDBDecodeBase64Key = "tidb_decode_base64_key"
|
|
|
|
// MVCC information fetching function.
|
|
GetMvccInfo = "get_mvcc_info"
|
|
|
|
// Sequence function.
|
|
NextVal = "nextval"
|
|
LastVal = "lastval"
|
|
SetVal = "setval"
|
|
)
|
|
|
|
type FuncCallExprType int8
|
|
|
|
const (
|
|
FuncCallExprTypeKeyword FuncCallExprType = iota
|
|
FuncCallExprTypeGeneric
|
|
)
|
|
|
|
// FuncCallExpr is for function expression.
|
|
type FuncCallExpr struct {
|
|
funcNode
|
|
Tp FuncCallExprType
|
|
Schema model.CIStr
|
|
// FnName is the function name.
|
|
FnName model.CIStr
|
|
// Args is the function args.
|
|
Args []ExprNode
|
|
}
|
|
|
|
// Restore implements Node interface.
|
|
func (n *FuncCallExpr) Restore(ctx *format.RestoreCtx) error {
|
|
done, err := n.customRestore(ctx)
|
|
if done {
|
|
return err
|
|
}
|
|
|
|
if len(n.Schema.String()) != 0 {
|
|
ctx.WriteName(n.Schema.O)
|
|
ctx.WritePlain(".")
|
|
}
|
|
if n.Tp == FuncCallExprTypeGeneric {
|
|
ctx.WriteName(n.FnName.O)
|
|
} else {
|
|
ctx.WriteKeyWord(n.FnName.O)
|
|
}
|
|
|
|
ctx.WritePlain("(")
|
|
switch n.FnName.L {
|
|
case "convert":
|
|
if err := n.Args[0].Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore FuncCastExpr.Expr")
|
|
}
|
|
ctx.WriteKeyWord(" USING ")
|
|
if err := n.Args[1].Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore FuncCastExpr.Expr")
|
|
}
|
|
case "adddate", "subdate", "date_add", "date_sub":
|
|
if err := n.Args[0].Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore FuncCallExpr.Args[0]")
|
|
}
|
|
ctx.WritePlain(", ")
|
|
ctx.WriteKeyWord("INTERVAL ")
|
|
if err := n.Args[1].Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore FuncCallExpr.Args[1]")
|
|
}
|
|
ctx.WritePlain(" ")
|
|
if err := n.Args[2].Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore FuncCallExpr.Args[2]")
|
|
}
|
|
case "extract":
|
|
if err := n.Args[0].Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore FuncCallExpr.Args[0]")
|
|
}
|
|
ctx.WriteKeyWord(" FROM ")
|
|
if err := n.Args[1].Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore FuncCallExpr.Args[1]")
|
|
}
|
|
case "position":
|
|
if err := n.Args[0].Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore FuncCallExpr")
|
|
}
|
|
ctx.WriteKeyWord(" IN ")
|
|
if err := n.Args[1].Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore FuncCallExpr")
|
|
}
|
|
case "trim":
|
|
switch len(n.Args) {
|
|
case 3:
|
|
if err := n.Args[2].Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore FuncCallExpr.Args[2]")
|
|
}
|
|
ctx.WritePlain(" ")
|
|
fallthrough
|
|
case 2:
|
|
if expr, isValue := n.Args[1].(ValueExpr); !isValue || expr.GetValue() != nil {
|
|
if err := n.Args[1].Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore FuncCallExpr.Args[1]")
|
|
}
|
|
ctx.WritePlain(" ")
|
|
}
|
|
ctx.WriteKeyWord("FROM ")
|
|
fallthrough
|
|
case 1:
|
|
if err := n.Args[0].Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore FuncCallExpr.Args[0]")
|
|
}
|
|
}
|
|
case WeightString:
|
|
if err := n.Args[0].Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore FuncCallExpr.(WEIGHT_STRING).Args[0]")
|
|
}
|
|
if len(n.Args) == 3 {
|
|
ctx.WriteKeyWord(" AS ")
|
|
ctx.WriteKeyWord(n.Args[1].(ValueExpr).GetValue().(string))
|
|
ctx.WritePlain("(")
|
|
if err := n.Args[2].Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore FuncCallExpr.(WEIGHT_STRING).Args[2]")
|
|
}
|
|
ctx.WritePlain(")")
|
|
}
|
|
default:
|
|
for i, argv := range n.Args {
|
|
if i != 0 {
|
|
ctx.WritePlain(", ")
|
|
}
|
|
if err := argv.Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore FuncCallExpr.Args %d", i)
|
|
}
|
|
}
|
|
}
|
|
ctx.WritePlain(")")
|
|
return nil
|
|
}
|
|
|
|
func (n *FuncCallExpr) customRestore(ctx *format.RestoreCtx) (bool, error) {
|
|
var specialLiteral string
|
|
switch n.FnName.L {
|
|
case DateLiteral:
|
|
specialLiteral = "DATE "
|
|
case TimeLiteral:
|
|
specialLiteral = "TIME "
|
|
case TimestampLiteral:
|
|
specialLiteral = "TIMESTAMP "
|
|
}
|
|
if specialLiteral != "" {
|
|
ctx.WritePlain(specialLiteral)
|
|
if err := n.Args[0].Restore(ctx); err != nil {
|
|
return true, errors.Annotatef(err, "An error occurred while restore FuncCallExpr.Expr")
|
|
}
|
|
return true, nil
|
|
}
|
|
if n.FnName.L == JSONMemberOf {
|
|
if err := n.Args[0].Restore(ctx); err != nil {
|
|
return true, errors.Annotatef(err, "An error occurred while restore FuncCallExpr.(MEMBER OF).Args[0]")
|
|
}
|
|
ctx.WriteKeyWord(" MEMBER OF ")
|
|
ctx.WritePlain("(")
|
|
if err := n.Args[1].Restore(ctx); err != nil {
|
|
return true, errors.Annotatef(err, "An error occurred while restore FuncCallExpr.(MEMBER OF).Args[1]")
|
|
}
|
|
ctx.WritePlain(")")
|
|
return true, nil
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
// Format the ExprNode into a Writer.
|
|
func (n *FuncCallExpr) Format(w io.Writer) {
|
|
if !n.specialFormatArgs(w) {
|
|
fmt.Fprintf(w, "%s(", n.FnName.L)
|
|
for i, arg := range n.Args {
|
|
arg.Format(w)
|
|
if i != len(n.Args)-1 {
|
|
fmt.Fprint(w, ", ")
|
|
}
|
|
}
|
|
fmt.Fprint(w, ")")
|
|
}
|
|
}
|
|
|
|
// specialFormatArgs formats argument list for some special functions.
|
|
func (n *FuncCallExpr) specialFormatArgs(w io.Writer) bool {
|
|
switch n.FnName.L {
|
|
case DateAdd, DateSub, AddDate, SubDate:
|
|
fmt.Fprintf(w, "%s(", n.FnName.L)
|
|
n.Args[0].Format(w)
|
|
fmt.Fprint(w, ", INTERVAL ")
|
|
n.Args[1].Format(w)
|
|
fmt.Fprint(w, " ")
|
|
n.Args[2].Format(w)
|
|
fmt.Fprint(w, ")")
|
|
return true
|
|
case JSONMemberOf:
|
|
n.Args[0].Format(w)
|
|
fmt.Fprint(w, " MEMBER OF ")
|
|
fmt.Fprint(w, " (")
|
|
n.Args[1].Format(w)
|
|
fmt.Fprint(w, ")")
|
|
return true
|
|
case Extract:
|
|
fmt.Fprintf(w, "%s(", n.FnName.L)
|
|
n.Args[0].Format(w)
|
|
fmt.Fprint(w, " FROM ")
|
|
n.Args[1].Format(w)
|
|
fmt.Fprint(w, ")")
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Accept implements Node interface.
|
|
func (n *FuncCallExpr) Accept(v Visitor) (Node, bool) {
|
|
newNode, skipChildren := v.Enter(n)
|
|
if skipChildren {
|
|
return v.Leave(newNode)
|
|
}
|
|
n = newNode.(*FuncCallExpr)
|
|
for i, val := range n.Args {
|
|
node, ok := val.Accept(v)
|
|
if !ok {
|
|
return n, false
|
|
}
|
|
n.Args[i] = node.(ExprNode)
|
|
}
|
|
return v.Leave(n)
|
|
}
|
|
|
|
// CastFunctionType is the type for cast function.
|
|
type CastFunctionType int
|
|
|
|
// CastFunction types
|
|
const (
|
|
CastFunction CastFunctionType = iota + 1
|
|
CastConvertFunction
|
|
CastBinaryOperator
|
|
)
|
|
|
|
// FuncCastExpr is the cast function converting value to another type, e.g, cast(expr AS signed).
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html
|
|
type FuncCastExpr struct {
|
|
funcNode
|
|
// Expr is the expression to be converted.
|
|
Expr ExprNode
|
|
// Tp is the conversion type.
|
|
Tp *types.FieldType
|
|
// FunctionType is either Cast, Convert or Binary.
|
|
FunctionType CastFunctionType
|
|
// ExplicitCharSet is true when charset is explicit indicated.
|
|
ExplicitCharSet bool
|
|
}
|
|
|
|
// Restore implements Node interface.
|
|
func (n *FuncCastExpr) Restore(ctx *format.RestoreCtx) error {
|
|
switch n.FunctionType {
|
|
case CastFunction:
|
|
ctx.WriteKeyWord("CAST")
|
|
ctx.WritePlain("(")
|
|
if err := n.Expr.Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore FuncCastExpr.Expr")
|
|
}
|
|
ctx.WriteKeyWord(" AS ")
|
|
n.Tp.RestoreAsCastType(ctx, n.ExplicitCharSet)
|
|
ctx.WritePlain(")")
|
|
case CastConvertFunction:
|
|
ctx.WriteKeyWord("CONVERT")
|
|
ctx.WritePlain("(")
|
|
if err := n.Expr.Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore FuncCastExpr.Expr")
|
|
}
|
|
ctx.WritePlain(", ")
|
|
n.Tp.RestoreAsCastType(ctx, n.ExplicitCharSet)
|
|
ctx.WritePlain(")")
|
|
case CastBinaryOperator:
|
|
ctx.WriteKeyWord("BINARY ")
|
|
if err := n.Expr.Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore FuncCastExpr.Expr")
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Format the ExprNode into a Writer.
|
|
func (n *FuncCastExpr) Format(w io.Writer) {
|
|
switch n.FunctionType {
|
|
case CastFunction:
|
|
fmt.Fprint(w, "CAST(")
|
|
n.Expr.Format(w)
|
|
fmt.Fprint(w, " AS ")
|
|
n.Tp.FormatAsCastType(w, n.ExplicitCharSet)
|
|
fmt.Fprint(w, ")")
|
|
case CastConvertFunction:
|
|
fmt.Fprint(w, "CONVERT(")
|
|
n.Expr.Format(w)
|
|
fmt.Fprint(w, ", ")
|
|
n.Tp.FormatAsCastType(w, n.ExplicitCharSet)
|
|
fmt.Fprint(w, ")")
|
|
case CastBinaryOperator:
|
|
fmt.Fprint(w, "BINARY ")
|
|
n.Expr.Format(w)
|
|
}
|
|
}
|
|
|
|
// Accept implements Node Accept interface.
|
|
func (n *FuncCastExpr) Accept(v Visitor) (Node, bool) {
|
|
newNode, skipChildren := v.Enter(n)
|
|
if skipChildren {
|
|
return v.Leave(newNode)
|
|
}
|
|
n = newNode.(*FuncCastExpr)
|
|
node, ok := n.Expr.Accept(v)
|
|
if !ok {
|
|
return n, false
|
|
}
|
|
n.Expr = node.(ExprNode)
|
|
return v.Leave(n)
|
|
}
|
|
|
|
// TrimDirectionType is the type for trim direction.
|
|
type TrimDirectionType int
|
|
|
|
const (
|
|
// TrimBothDefault trims from both direction by default.
|
|
TrimBothDefault TrimDirectionType = iota
|
|
// TrimBoth trims from both direction with explicit notation.
|
|
TrimBoth
|
|
// TrimLeading trims from left.
|
|
TrimLeading
|
|
// TrimTrailing trims from right.
|
|
TrimTrailing
|
|
)
|
|
|
|
// String implements fmt.Stringer interface.
|
|
func (direction TrimDirectionType) String() string {
|
|
switch direction {
|
|
case TrimBoth, TrimBothDefault:
|
|
return "BOTH"
|
|
case TrimLeading:
|
|
return "LEADING"
|
|
case TrimTrailing:
|
|
return "TRAILING"
|
|
default:
|
|
return ""
|
|
}
|
|
}
|
|
|
|
// TrimDirectionExpr is an expression representing the trim direction used in the TRIM() function.
|
|
type TrimDirectionExpr struct {
|
|
exprNode
|
|
// Direction is the trim direction
|
|
Direction TrimDirectionType
|
|
}
|
|
|
|
// Restore implements Node interface.
|
|
func (n *TrimDirectionExpr) Restore(ctx *format.RestoreCtx) error {
|
|
ctx.WriteKeyWord(n.Direction.String())
|
|
return nil
|
|
}
|
|
|
|
// Format the ExprNode into a Writer.
|
|
func (n *TrimDirectionExpr) Format(w io.Writer) {
|
|
fmt.Fprint(w, n.Direction.String())
|
|
}
|
|
|
|
// Accept implements Node Accept interface.
|
|
func (n *TrimDirectionExpr) Accept(v Visitor) (Node, bool) {
|
|
newNode, skipChildren := v.Enter(n)
|
|
if skipChildren {
|
|
return v.Leave(newNode)
|
|
}
|
|
return v.Leave(n)
|
|
}
|
|
|
|
// DateArithType is type for DateArith type.
|
|
type DateArithType byte
|
|
|
|
const (
|
|
// DateArithAdd is to run adddate or date_add function option.
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_adddate
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add
|
|
DateArithAdd DateArithType = iota + 1
|
|
// DateArithSub is to run subdate or date_sub function option.
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subdate
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-sub
|
|
DateArithSub
|
|
)
|
|
|
|
const (
|
|
// AggFuncCount is the name of Count function.
|
|
AggFuncCount = "count"
|
|
// AggFuncSum is the name of Sum function.
|
|
AggFuncSum = "sum"
|
|
// AggFuncAvg is the name of Avg function.
|
|
AggFuncAvg = "avg"
|
|
// AggFuncFirstRow is the name of FirstRowColumn function.
|
|
AggFuncFirstRow = "firstrow"
|
|
// AggFuncMax is the name of max function.
|
|
AggFuncMax = "max"
|
|
// AggFuncMin is the name of min function.
|
|
AggFuncMin = "min"
|
|
// AggFuncGroupConcat is the name of group_concat function.
|
|
AggFuncGroupConcat = "group_concat"
|
|
// AggFuncBitOr is the name of bit_or function.
|
|
AggFuncBitOr = "bit_or"
|
|
// AggFuncBitXor is the name of bit_xor function.
|
|
AggFuncBitXor = "bit_xor"
|
|
// AggFuncBitAnd is the name of bit_and function.
|
|
AggFuncBitAnd = "bit_and"
|
|
// AggFuncVarPop is the name of var_pop function
|
|
AggFuncVarPop = "var_pop"
|
|
// AggFuncVarSamp is the name of var_samp function
|
|
AggFuncVarSamp = "var_samp"
|
|
// AggFuncStddevPop is the name of stddev_pop/std/stddev function
|
|
AggFuncStddevPop = "stddev_pop"
|
|
// AggFuncStddevSamp is the name of stddev_samp function
|
|
AggFuncStddevSamp = "stddev_samp"
|
|
// AggFuncJsonArrayagg is the name of json_arrayagg function
|
|
AggFuncJsonArrayagg = "json_arrayagg"
|
|
// AggFuncJsonObjectAgg is the name of json_objectagg function
|
|
AggFuncJsonObjectAgg = "json_objectagg"
|
|
// AggFuncApproxCountDistinct is the name of approx_count_distinct function.
|
|
AggFuncApproxCountDistinct = "approx_count_distinct"
|
|
// AggFuncApproxPercentile is the name of approx_percentile function.
|
|
AggFuncApproxPercentile = "approx_percentile"
|
|
)
|
|
|
|
// AggregateFuncExpr represents aggregate function expression.
|
|
type AggregateFuncExpr struct {
|
|
funcNode
|
|
// F is the function name.
|
|
F string
|
|
// Args is the function args.
|
|
Args []ExprNode
|
|
// Distinct is true, function hence only aggregate distinct values.
|
|
// For example, column c1 values are "1", "2", "2", "sum(c1)" is "5",
|
|
// but "sum(distinct c1)" is "3".
|
|
Distinct bool
|
|
// Order is only used in GROUP_CONCAT
|
|
Order *OrderByClause
|
|
}
|
|
|
|
// Restore implements Node interface.
|
|
func (n *AggregateFuncExpr) Restore(ctx *format.RestoreCtx) error {
|
|
ctx.WriteKeyWord(n.F)
|
|
ctx.WritePlain("(")
|
|
if n.Distinct {
|
|
ctx.WriteKeyWord("DISTINCT ")
|
|
}
|
|
switch strings.ToLower(n.F) {
|
|
case "group_concat":
|
|
for i := 0; i < len(n.Args)-1; i++ {
|
|
if i != 0 {
|
|
ctx.WritePlain(", ")
|
|
}
|
|
if err := n.Args[i].Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore AggregateFuncExpr.Args[%d]", i)
|
|
}
|
|
}
|
|
if n.Order != nil {
|
|
ctx.WritePlain(" ")
|
|
if err := n.Order.Restore(ctx); err != nil {
|
|
return errors.Annotate(err, "An error occur while restore AggregateFuncExpr.Args Order")
|
|
}
|
|
}
|
|
ctx.WriteKeyWord(" SEPARATOR ")
|
|
if err := n.Args[len(n.Args)-1].Restore(ctx); err != nil {
|
|
return errors.Annotate(err, "An error occurred while restore AggregateFuncExpr.Args SEPARATOR")
|
|
}
|
|
default:
|
|
for i, argv := range n.Args {
|
|
if i != 0 {
|
|
ctx.WritePlain(", ")
|
|
}
|
|
if err := argv.Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore AggregateFuncExpr.Args[%d]", i)
|
|
}
|
|
}
|
|
}
|
|
ctx.WritePlain(")")
|
|
return nil
|
|
}
|
|
|
|
// Format the ExprNode into a Writer.
|
|
func (n *AggregateFuncExpr) Format(w io.Writer) {
|
|
panic("Not implemented")
|
|
}
|
|
|
|
// Accept implements Node Accept interface.
|
|
func (n *AggregateFuncExpr) Accept(v Visitor) (Node, bool) {
|
|
newNode, skipChildren := v.Enter(n)
|
|
if skipChildren {
|
|
return v.Leave(newNode)
|
|
}
|
|
n = newNode.(*AggregateFuncExpr)
|
|
for i, val := range n.Args {
|
|
node, ok := val.Accept(v)
|
|
if !ok {
|
|
return n, false
|
|
}
|
|
n.Args[i] = node.(ExprNode)
|
|
}
|
|
if n.Order != nil {
|
|
node, ok := n.Order.Accept(v)
|
|
if !ok {
|
|
return n, false
|
|
}
|
|
n.Order = node.(*OrderByClause)
|
|
}
|
|
return v.Leave(n)
|
|
}
|
|
|
|
const (
|
|
// WindowFuncRowNumber is the name of row_number function.
|
|
WindowFuncRowNumber = "row_number"
|
|
// WindowFuncRank is the name of rank function.
|
|
WindowFuncRank = "rank"
|
|
// WindowFuncDenseRank is the name of dense_rank function.
|
|
WindowFuncDenseRank = "dense_rank"
|
|
// WindowFuncCumeDist is the name of cume_dist function.
|
|
WindowFuncCumeDist = "cume_dist"
|
|
// WindowFuncPercentRank is the name of percent_rank function.
|
|
WindowFuncPercentRank = "percent_rank"
|
|
// WindowFuncNtile is the name of ntile function.
|
|
WindowFuncNtile = "ntile"
|
|
// WindowFuncLead is the name of lead function.
|
|
WindowFuncLead = "lead"
|
|
// WindowFuncLag is the name of lag function.
|
|
WindowFuncLag = "lag"
|
|
// WindowFuncFirstValue is the name of first_value function.
|
|
WindowFuncFirstValue = "first_value"
|
|
// WindowFuncLastValue is the name of last_value function.
|
|
WindowFuncLastValue = "last_value"
|
|
// WindowFuncNthValue is the name of nth_value function.
|
|
WindowFuncNthValue = "nth_value"
|
|
)
|
|
|
|
// WindowFuncExpr represents window function expression.
|
|
type WindowFuncExpr struct {
|
|
funcNode
|
|
|
|
// Name is the function name.
|
|
Name string
|
|
// Args is the function args.
|
|
Args []ExprNode
|
|
// Distinct cannot be true for most window functions, except `max` and `min`.
|
|
// We need to raise error if it is not allowed to be true.
|
|
Distinct bool
|
|
// IgnoreNull indicates how to handle null value.
|
|
// MySQL only supports `RESPECT NULLS`, so we need to raise error if it is true.
|
|
IgnoreNull bool
|
|
// FromLast indicates the calculation direction of this window function.
|
|
// MySQL only supports calculation from first, so we need to raise error if it is true.
|
|
FromLast bool
|
|
// Spec is the specification of this window.
|
|
Spec WindowSpec
|
|
}
|
|
|
|
// Restore implements Node interface.
|
|
func (n *WindowFuncExpr) Restore(ctx *format.RestoreCtx) error {
|
|
ctx.WriteKeyWord(n.Name)
|
|
ctx.WritePlain("(")
|
|
for i, v := range n.Args {
|
|
if i != 0 {
|
|
ctx.WritePlain(", ")
|
|
} else if n.Distinct {
|
|
ctx.WriteKeyWord("DISTINCT ")
|
|
}
|
|
if err := v.Restore(ctx); err != nil {
|
|
return errors.Annotatef(err, "An error occurred while restore WindowFuncExpr.Args[%d]", i)
|
|
}
|
|
}
|
|
ctx.WritePlain(")")
|
|
if n.FromLast {
|
|
ctx.WriteKeyWord(" FROM LAST")
|
|
}
|
|
if n.IgnoreNull {
|
|
ctx.WriteKeyWord(" IGNORE NULLS")
|
|
}
|
|
ctx.WriteKeyWord(" OVER ")
|
|
if err := n.Spec.Restore(ctx); err != nil {
|
|
return errors.Annotate(err, "An error occurred while restore WindowFuncExpr.Spec")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Format formats the window function expression into a Writer.
|
|
func (n *WindowFuncExpr) Format(w io.Writer) {
|
|
panic("Not implemented")
|
|
}
|
|
|
|
// Accept implements Node Accept interface.
|
|
func (n *WindowFuncExpr) Accept(v Visitor) (Node, bool) {
|
|
newNode, skipChildren := v.Enter(n)
|
|
if skipChildren {
|
|
return v.Leave(newNode)
|
|
}
|
|
n = newNode.(*WindowFuncExpr)
|
|
for i, val := range n.Args {
|
|
node, ok := val.Accept(v)
|
|
if !ok {
|
|
return n, false
|
|
}
|
|
n.Args[i] = node.(ExprNode)
|
|
}
|
|
node, ok := n.Spec.Accept(v)
|
|
if !ok {
|
|
return n, false
|
|
}
|
|
n.Spec = *node.(*WindowSpec)
|
|
return v.Leave(n)
|
|
}
|
|
|
|
// TimeUnitType is the type for time and timestamp units.
|
|
type TimeUnitType int
|
|
|
|
const (
|
|
// TimeUnitInvalid is a placeholder for an invalid time or timestamp unit
|
|
TimeUnitInvalid TimeUnitType = iota
|
|
// TimeUnitMicrosecond is the time or timestamp unit MICROSECOND.
|
|
TimeUnitMicrosecond
|
|
// TimeUnitSecond is the time or timestamp unit SECOND.
|
|
TimeUnitSecond
|
|
// TimeUnitMinute is the time or timestamp unit MINUTE.
|
|
TimeUnitMinute
|
|
// TimeUnitHour is the time or timestamp unit HOUR.
|
|
TimeUnitHour
|
|
// TimeUnitDay is the time or timestamp unit DAY.
|
|
TimeUnitDay
|
|
// TimeUnitWeek is the time or timestamp unit WEEK.
|
|
TimeUnitWeek
|
|
// TimeUnitMonth is the time or timestamp unit MONTH.
|
|
TimeUnitMonth
|
|
// TimeUnitQuarter is the time or timestamp unit QUARTER.
|
|
TimeUnitQuarter
|
|
// TimeUnitYear is the time or timestamp unit YEAR.
|
|
TimeUnitYear
|
|
// TimeUnitSecondMicrosecond is the time unit SECOND_MICROSECOND.
|
|
TimeUnitSecondMicrosecond
|
|
// TimeUnitMinuteMicrosecond is the time unit MINUTE_MICROSECOND.
|
|
TimeUnitMinuteMicrosecond
|
|
// TimeUnitMinuteSecond is the time unit MINUTE_SECOND.
|
|
TimeUnitMinuteSecond
|
|
// TimeUnitHourMicrosecond is the time unit HOUR_MICROSECOND.
|
|
TimeUnitHourMicrosecond
|
|
// TimeUnitHourSecond is the time unit HOUR_SECOND.
|
|
TimeUnitHourSecond
|
|
// TimeUnitHourMinute is the time unit HOUR_MINUTE.
|
|
TimeUnitHourMinute
|
|
// TimeUnitDayMicrosecond is the time unit DAY_MICROSECOND.
|
|
TimeUnitDayMicrosecond
|
|
// TimeUnitDaySecond is the time unit DAY_SECOND.
|
|
TimeUnitDaySecond
|
|
// TimeUnitDayMinute is the time unit DAY_MINUTE.
|
|
TimeUnitDayMinute
|
|
// TimeUnitDayHour is the time unit DAY_HOUR.
|
|
TimeUnitDayHour
|
|
// TimeUnitYearMonth is the time unit YEAR_MONTH.
|
|
TimeUnitYearMonth
|
|
)
|
|
|
|
// String implements fmt.Stringer interface.
|
|
func (unit TimeUnitType) String() string {
|
|
switch unit {
|
|
case TimeUnitMicrosecond:
|
|
return "MICROSECOND"
|
|
case TimeUnitSecond:
|
|
return "SECOND"
|
|
case TimeUnitMinute:
|
|
return "MINUTE"
|
|
case TimeUnitHour:
|
|
return "HOUR"
|
|
case TimeUnitDay:
|
|
return "DAY"
|
|
case TimeUnitWeek:
|
|
return "WEEK"
|
|
case TimeUnitMonth:
|
|
return "MONTH"
|
|
case TimeUnitQuarter:
|
|
return "QUARTER"
|
|
case TimeUnitYear:
|
|
return "YEAR"
|
|
case TimeUnitSecondMicrosecond:
|
|
return "SECOND_MICROSECOND"
|
|
case TimeUnitMinuteMicrosecond:
|
|
return "MINUTE_MICROSECOND"
|
|
case TimeUnitMinuteSecond:
|
|
return "MINUTE_SECOND"
|
|
case TimeUnitHourMicrosecond:
|
|
return "HOUR_MICROSECOND"
|
|
case TimeUnitHourSecond:
|
|
return "HOUR_SECOND"
|
|
case TimeUnitHourMinute:
|
|
return "HOUR_MINUTE"
|
|
case TimeUnitDayMicrosecond:
|
|
return "DAY_MICROSECOND"
|
|
case TimeUnitDaySecond:
|
|
return "DAY_SECOND"
|
|
case TimeUnitDayMinute:
|
|
return "DAY_MINUTE"
|
|
case TimeUnitDayHour:
|
|
return "DAY_HOUR"
|
|
case TimeUnitYearMonth:
|
|
return "YEAR_MONTH"
|
|
default:
|
|
return ""
|
|
}
|
|
}
|
|
|
|
// Duration represented by this unit.
|
|
// Returns error if the time unit is not a fixed time interval (such as MONTH)
|
|
// or a composite unit (such as MINUTE_SECOND).
|
|
func (unit TimeUnitType) Duration() (time.Duration, error) {
|
|
switch unit {
|
|
case TimeUnitMicrosecond:
|
|
return time.Microsecond, nil
|
|
case TimeUnitSecond:
|
|
return time.Second, nil
|
|
case TimeUnitMinute:
|
|
return time.Minute, nil
|
|
case TimeUnitHour:
|
|
return time.Hour, nil
|
|
case TimeUnitDay:
|
|
return time.Hour * 24, nil
|
|
case TimeUnitWeek:
|
|
return time.Hour * 24 * 7, nil
|
|
case TimeUnitMonth, TimeUnitQuarter, TimeUnitYear:
|
|
return 0, errors.Errorf("%s is not a constant time interval and cannot be used here", unit)
|
|
default:
|
|
return 0, errors.Errorf("%s is a composite time unit and is not supported yet", unit)
|
|
}
|
|
}
|
|
|
|
// TimeUnitExpr is an expression representing a time or timestamp unit.
|
|
type TimeUnitExpr struct {
|
|
exprNode
|
|
// Unit is the time or timestamp unit.
|
|
Unit TimeUnitType
|
|
}
|
|
|
|
// Restore implements Node interface.
|
|
func (n *TimeUnitExpr) Restore(ctx *format.RestoreCtx) error {
|
|
ctx.WriteKeyWord(n.Unit.String())
|
|
return nil
|
|
}
|
|
|
|
// Format the ExprNode into a Writer.
|
|
func (n *TimeUnitExpr) Format(w io.Writer) {
|
|
fmt.Fprint(w, n.Unit.String())
|
|
}
|
|
|
|
// Accept implements Node Accept interface.
|
|
func (n *TimeUnitExpr) Accept(v Visitor) (Node, bool) {
|
|
newNode, skipChildren := v.Enter(n)
|
|
if skipChildren {
|
|
return v.Leave(newNode)
|
|
}
|
|
return v.Leave(n)
|
|
}
|
|
|
|
// GetFormatSelectorType is the type for the first argument of GET_FORMAT() function.
|
|
type GetFormatSelectorType int
|
|
|
|
const (
|
|
// GetFormatSelectorDate is the GET_FORMAT selector DATE.
|
|
GetFormatSelectorDate GetFormatSelectorType = iota + 1
|
|
// GetFormatSelectorTime is the GET_FORMAT selector TIME.
|
|
GetFormatSelectorTime
|
|
// GetFormatSelectorDatetime is the GET_FORMAT selector DATETIME and TIMESTAMP.
|
|
GetFormatSelectorDatetime
|
|
)
|
|
|
|
// GetFormatSelectorExpr is an expression used as the first argument of GET_FORMAT() function.
|
|
type GetFormatSelectorExpr struct {
|
|
exprNode
|
|
// Selector is the GET_FORMAT() selector.
|
|
Selector GetFormatSelectorType
|
|
}
|
|
|
|
// String implements fmt.Stringer interface.
|
|
func (selector GetFormatSelectorType) String() string {
|
|
switch selector {
|
|
case GetFormatSelectorDate:
|
|
return "DATE"
|
|
case GetFormatSelectorTime:
|
|
return "TIME"
|
|
case GetFormatSelectorDatetime:
|
|
return "DATETIME"
|
|
default:
|
|
return ""
|
|
}
|
|
}
|
|
|
|
// Restore implements Node interface.
|
|
func (n *GetFormatSelectorExpr) Restore(ctx *format.RestoreCtx) error {
|
|
ctx.WriteKeyWord(n.Selector.String())
|
|
return nil
|
|
}
|
|
|
|
// Format the ExprNode into a Writer.
|
|
func (n *GetFormatSelectorExpr) Format(w io.Writer) {
|
|
fmt.Fprint(w, n.Selector.String())
|
|
}
|
|
|
|
// Accept implements Node Accept interface.
|
|
func (n *GetFormatSelectorExpr) Accept(v Visitor) (Node, bool) {
|
|
newNode, skipChildren := v.Enter(n)
|
|
if skipChildren {
|
|
return v.Leave(newNode)
|
|
}
|
|
return v.Leave(n)
|
|
}
|