Files
tidb/pkg/expression/constant.go

674 lines
19 KiB
Go

// Copyright 2017 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package expression
import (
"fmt"
"slices"
"unsafe"
perrors "github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/planner/cascades/base"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/codec"
"github.com/pingcap/tidb/pkg/util/collate"
"github.com/pingcap/tidb/pkg/util/intest"
"github.com/pingcap/tidb/pkg/util/logutil"
"go.uber.org/zap"
)
// Compile-time assertion that *Constant satisfies base.HashEquals.
var _ base.HashEquals = (*Constant)(nil)
// NewOne stands for an unsigned number 1.
//
// The returned constant is typed TINYINT(1) UNSIGNED with zero decimals.
func NewOne() *Constant {
	ft := types.NewFieldType(mysql.TypeTiny)
	// Marking the literal unsigned shrinks its range so it does not trigger integral promotion.
	ft.AddFlag(mysql.UnsignedFlag)
	ft.SetFlen(1)
	ft.SetDecimal(0)
	one := types.NewDatum(1)
	return &Constant{RetType: ft, Value: one}
}
// NewSignedOne stands for a signed number 1.
//
// The returned constant is typed TINYINT(1) (signed) with zero decimals.
func NewSignedOne() *Constant {
	ft := types.NewFieldType(mysql.TypeTiny)
	ft.SetFlen(1)
	ft.SetDecimal(0)
	one := types.NewDatum(1)
	return &Constant{RetType: ft, Value: one}
}
// NewZero stands for an unsigned number 0.
//
// The returned constant is typed TINYINT(1) UNSIGNED with zero decimals.
func NewZero() *Constant {
	ft := types.NewFieldType(mysql.TypeTiny)
	// Marking the literal unsigned shrinks its range so it does not trigger integral promotion.
	ft.AddFlag(mysql.UnsignedFlag)
	ft.SetFlen(1)
	ft.SetDecimal(0)
	zero := types.NewDatum(0)
	return &Constant{RetType: ft, Value: zero}
}
// NewSignedZero stands for a signed number 0.
//
// The returned constant is typed TINYINT(1) (signed) with zero decimals.
func NewSignedZero() *Constant {
	ft := types.NewFieldType(mysql.TypeTiny)
	ft.SetFlen(1)
	ft.SetDecimal(0)
	zero := types.NewDatum(0)
	return &Constant{RetType: ft, Value: zero}
}
// NewUInt64Const stands for constant of a given number.
//
// The field type is an unsigned BIGINT of width mysql.MaxIntWidth with zero
// decimals. num is a Go int, so the datum is built from a signed value even
// though the field type carries the unsigned flag.
func NewUInt64Const(num int) *Constant {
	ft := types.NewFieldType(mysql.TypeLonglong)
	// Marking the literal unsigned shrinks its range so it does not trigger integral promotion.
	ft.AddFlag(mysql.UnsignedFlag)
	ft.SetFlen(mysql.MaxIntWidth)
	ft.SetDecimal(0)
	val := types.NewDatum(num)
	return &Constant{RetType: ft, Value: val}
}
// NewUInt64ConstWithFieldType stands for constant of a given number with specified fieldType.
//
// fieldType is stored as-is (not copied), so the caller keeps sharing it with
// the returned constant.
func NewUInt64ConstWithFieldType(num uint64, fieldType *types.FieldType) *Constant {
	c := &Constant{RetType: fieldType}
	c.Value = types.NewDatum(num)
	return c
}
// NewInt64Const stands for constant of a given number.
//
// The returned constant is typed as a signed BIGINT of width mysql.MaxIntWidth
// with zero decimals.
func NewInt64Const(num int64) *Constant {
	ft := types.NewFieldType(mysql.TypeLonglong)
	ft.SetFlen(mysql.MaxIntWidth)
	ft.SetDecimal(0)
	val := types.NewDatum(num)
	return &Constant{RetType: ft, Value: val}
}
// NewStrConst stands for constant of a given string.
// used in test only now.
//
// The field type is VARSTRING whose flen is the byte length of str.
func NewStrConst(str string) *Constant {
	ft := types.NewFieldType(mysql.TypeVarString)
	ft.SetFlen(len(str))
	val := types.NewDatum(str)
	return &Constant{RetType: ft, Value: val}
}
// NewNull stands for null constant.
//
// The NULL datum is typed as TINYINT(1) with zero decimals.
func NewNull() *Constant {
	ft := types.NewFieldType(mysql.TypeTiny)
	ft.SetFlen(1)
	ft.SetDecimal(0)
	null := types.NewDatum(nil)
	return &Constant{RetType: ft, Value: null}
}
// NewNullWithFieldType stands for null constant with specified fieldType.
//
// fieldType is stored as-is (not copied).
func NewNullWithFieldType(fieldType *types.FieldType) *Constant {
	c := &Constant{RetType: fieldType}
	c.Value = types.NewDatum(nil)
	return c
}
// Constant stands for a constant value.
//
// A Constant takes one of three shapes: a parameter marker (ParamMarker != nil)
// whose value is read from the execution parameters, a deferred expression
// (DeferredExpr != nil) that is evaluated on demand, or a plain literal held in
// Value. When evaluating, getLazyDatum gives ParamMarker precedence over
// DeferredExpr; the plain Value is used only when neither is set.
type Constant struct {
// Value is the literal value of a plain constant.
Value types.Datum
// RetType is the declared result type. NOTE(review): the plan-cache-clone tag
// presumably tells the plan-cache clone generator to copy the pointer rather
// than deep-clone it — confirm against the generator.
RetType *types.FieldType `plan-cache-clone:"shallow"`
// DeferredExpr holds deferred function in PlanCache cached plan.
// it's only used to represent non-deterministic functions(see expression.DeferredFunctions)
// in PlanCache cached plan, so let them can be evaluated until cached item be used.
DeferredExpr Expression
// ParamMarker holds param index inside sessionVars.PreparedParams.
// It's only used to reference a user variable provided in the `EXECUTE` statement or `COM_EXECUTE` binary protocol.
ParamMarker *ParamMarker
// hashcode caches the bytes produced by getHashCode.
hashcode []byte
// SubqueryRefID holds the ID of the original subquery column for display purposes.
// When positive, StringWithCtx renders the constant as ScalarQueryCol#<id>(<value>).
SubqueryRefID int64
// collationInfo carries the collation and coercibility metadata (see Coercibility).
collationInfo
}
// ParamMarker indicates param provided by COM_STMT_EXECUTE.
type ParamMarker struct {
// order is the parameter index passed to ParamValues.GetParamValue.
order int
}
// SafeToShareAcrossSession returns if the function can be shared across different sessions.
//
// A plain constant or a parameter marker is always shareable; a constant
// wrapping a deferred expression inherits that expression's answer.
func (c *Constant) SafeToShareAcrossSession() bool {
	if c.DeferredExpr == nil {
		return true
	}
	return c.DeferredExpr.SafeToShareAcrossSession()
}
// GetUserVar returns the corresponding user variable presented in the `EXECUTE` statement or `COM_EXECUTE` command.
//
// It looks up the parameter at this marker's position in ctx.
func (p *ParamMarker) GetUserVar(ctx ParamValues) (types.Datum, error) {
	idx := p.order
	return ctx.GetParamValue(idx)
}
// StringWithCtx implements Expression interface.
//
// A parameter marker is rendered with its bound value ("?" if the value cannot
// be fetched). A deferred expression is rendered by the expression itself. The
// value is then printed according to redact: verbatim, wrapped in ‹› markers,
// or fully masked as "?". A positive SubqueryRefID adds a
// ScalarQueryCol#<id>(...) wrapper.
func (c *Constant) StringWithCtx(ctx ParamValues, redact string) string {
	shown := c.Value
	switch {
	case c.ParamMarker != nil:
		bound, err := c.ParamMarker.GetUserVar(ctx)
		intest.AssertNoError(err, "fail to get param")
		if err != nil {
			return "?"
		}
		shown = bound
	case c.DeferredExpr != nil:
		return c.DeferredExpr.StringWithCtx(ctx, redact)
	}

	rendered := "?"
	switch redact {
	case perrors.RedactLogDisable:
		rendered = shown.TruncatedStringify()
	case perrors.RedactLogMarker:
		rendered = fmt.Sprintf("‹%s›", shown.TruncatedStringify())
	}

	if c.SubqueryRefID <= 0 {
		return rendered
	}
	// Keep the subquery column reference visible next to the folded value.
	ref := fmt.Sprintf("ScalarQueryCol#%d", c.SubqueryRefID)
	return fmt.Sprintf("%s(%s)", ref, rendered)
}
// Clone implements Expression interface.
//
// The clone gets its own copies of the return type, the parameter marker, the
// deferred expression and the cached hash code. Value is copied as a Datum
// struct (shallow). A nil RetType stays nil instead of being dereferenced;
// Hash64, Equals and MemoryUsage all accept a nil RetType, so Clone must too.
func (c *Constant) Clone() Expression {
	con := *c
	if c.RetType != nil {
		con.RetType = c.RetType.Clone()
	}
	if c.ParamMarker != nil {
		con.ParamMarker = &ParamMarker{order: c.ParamMarker.order}
	}
	if c.DeferredExpr != nil {
		con.DeferredExpr = c.DeferredExpr.Clone()
	}
	if c.hashcode != nil {
		con.hashcode = slices.Clone(c.hashcode)
	}
	return &con
}
// GetType implements Expression interface.
//
// A plain or deferred constant returns RetType. A parameter marker infers its
// type from the currently bound parameter value; it returns nil (after logging
// a warning) when the parameter cannot be fetched.
func (c *Constant) GetType(ctx EvalContext) *types.FieldType {
	if c.ParamMarker == nil {
		return c.RetType
	}
	bound, err := c.ParamMarker.GetUserVar(ctx)
	intest.AssertNoError(err, "fail to get param")
	if err != nil {
		logutil.BgLogger().Warn("fail to get param", zap.Error(err))
		return nil
	}
	// GetType() may be called in multi-threaded context, e.g, in building inner executors of IndexJoin,
	// so it should avoid data race. We achieve this by returning different FieldType pointer for each call.
	inferred := types.NewFieldType(mysql.TypeUnspecified)
	types.InferParamTypeFromDatum(&bound, inferred)
	return inferred
}
// VecEvalInt evaluates this expression in a vectorized manner.
// A deferred expression is evaluated directly; otherwise the constant value is
// broadcast into result.
func (c *Constant) VecEvalInt(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error {
	if c.DeferredExpr != nil {
		return c.DeferredExpr.VecEvalInt(ctx, input, result)
	}
	return genVecFromConstExpr(ctx, c, types.ETInt, input, result)
}

// VecEvalReal evaluates this expression in a vectorized manner.
// A deferred expression is evaluated directly; otherwise the constant value is
// broadcast into result.
func (c *Constant) VecEvalReal(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error {
	if c.DeferredExpr != nil {
		return c.DeferredExpr.VecEvalReal(ctx, input, result)
	}
	return genVecFromConstExpr(ctx, c, types.ETReal, input, result)
}

// VecEvalString evaluates this expression in a vectorized manner.
// A deferred expression is evaluated directly; otherwise the constant value is
// broadcast into result.
func (c *Constant) VecEvalString(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error {
	if c.DeferredExpr != nil {
		return c.DeferredExpr.VecEvalString(ctx, input, result)
	}
	return genVecFromConstExpr(ctx, c, types.ETString, input, result)
}

// VecEvalDecimal evaluates this expression in a vectorized manner.
// A deferred expression is evaluated directly; otherwise the constant value is
// broadcast into result.
func (c *Constant) VecEvalDecimal(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error {
	if c.DeferredExpr != nil {
		return c.DeferredExpr.VecEvalDecimal(ctx, input, result)
	}
	return genVecFromConstExpr(ctx, c, types.ETDecimal, input, result)
}

// VecEvalTime evaluates this expression in a vectorized manner.
// A deferred expression is evaluated directly; otherwise the constant value is
// broadcast into result.
func (c *Constant) VecEvalTime(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error {
	if c.DeferredExpr != nil {
		return c.DeferredExpr.VecEvalTime(ctx, input, result)
	}
	return genVecFromConstExpr(ctx, c, types.ETTimestamp, input, result)
}

// VecEvalDuration evaluates this expression in a vectorized manner.
// A deferred expression is evaluated directly; otherwise the constant value is
// broadcast into result.
func (c *Constant) VecEvalDuration(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error {
	if c.DeferredExpr != nil {
		return c.DeferredExpr.VecEvalDuration(ctx, input, result)
	}
	return genVecFromConstExpr(ctx, c, types.ETDuration, input, result)
}

// VecEvalJSON evaluates this expression in a vectorized manner.
// A deferred expression is evaluated directly; otherwise the constant value is
// broadcast into result.
func (c *Constant) VecEvalJSON(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error {
	if c.DeferredExpr != nil {
		return c.DeferredExpr.VecEvalJSON(ctx, input, result)
	}
	return genVecFromConstExpr(ctx, c, types.ETJson, input, result)
}

// VecEvalVectorFloat32 evaluates this expression in a vectorized manner.
// A deferred expression is evaluated directly; otherwise the constant value is
// broadcast into result.
func (c *Constant) VecEvalVectorFloat32(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error {
	if c.DeferredExpr != nil {
		return c.DeferredExpr.VecEvalVectorFloat32(ctx, input, result)
	}
	return genVecFromConstExpr(ctx, c, types.ETVectorFloat32, input, result)
}
// getLazyDatum resolves the value of a constant whose value is not stored in
// c.Value. isLazy is true for parameter markers (value read from ctx) and for
// deferred expressions (value computed by evaluating the expression on row).
// For plain constants it returns a zero Datum with isLazy == false and the
// caller is expected to fall back to c.Value.
func (c *Constant) getLazyDatum(ctx EvalContext, row chunk.Row) (dt types.Datum, isLazy bool, err error) {
	switch {
	case c.ParamMarker != nil:
		bound, getErr := c.ParamMarker.GetUserVar(ctx)
		intest.AssertNoError(getErr, "fail to get param")
		return bound, true, getErr
	case c.DeferredExpr != nil:
		dt, err = c.DeferredExpr.Eval(ctx, row)
		return dt, true, err
	default:
		return types.Datum{}, false, nil
	}
}
// Traverse implements the TraverseDown interface.
//
// A constant has no children, so the action is applied to the constant itself
// and its result is returned.
func (c *Constant) Traverse(action TraverseAction) Expression {
	var self Expression = c
	return action.Transform(self)
}
// Eval implements Expression interface.
//
// Plain constants return c.Value. Lazy constants (parameter markers or deferred
// expressions) return their resolved datum; when resolution fails, c.Value is
// returned together with the error. A non-NULL result of a deferred
// expression is additionally fitted to RetType: decimals are rounded via
// adjustDecimal, and other kinds are converted with ConvertTo.
func (c *Constant) Eval(ctx EvalContext, row chunk.Row) (types.Datum, error) {
	intest.AssertNotNil(ctx)
	dt, lazy, err := c.getLazyDatum(ctx, row)
	if !lazy {
		return c.Value, nil
	}
	if err != nil {
		return c.Value, err
	}
	if dt.IsNull() || c.DeferredExpr == nil {
		return dt, nil
	}
	if dt.Kind() == types.KindMysqlDecimal {
		if adjErr := c.adjustDecimal(ctx, dt.GetMysqlDecimal()); adjErr != nil {
			return dt, adjErr
		}
		return dt, nil
	}
	converted, convErr := dt.ConvertTo(typeCtx(ctx), c.RetType)
	if convErr != nil {
		// On conversion failure the unconverted datum is returned alongside the error.
		return dt, convErr
	}
	return converted, nil
}
// EvalInt returns int representation of Constant.
//
// The datum comes from the bound parameter or deferred expression when the
// constant is lazy, otherwise from c.Value. The result type is resolved once
// per call. For a parameter marker, GetType allocates a fresh FieldType and
// re-reads the parameter each time it is called, so calling it twice on this
// per-row path would double that work.
//
// Returns (value, isNull, err). For a binary literal the null flag is set
// whenever the conversion to an integer fails.
func (c *Constant) EvalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) {
	dt, lazy, err := c.getLazyDatum(ctx, row)
	if err != nil {
		return 0, false, err
	}
	if !lazy {
		dt = c.Value
	}
	tp := c.GetType(ctx)
	switch {
	case tp.GetType() == mysql.TypeNull || dt.IsNull():
		return 0, true, nil
	case dt.Kind() == types.KindBinaryLiteral:
		val, err := dt.GetBinaryLiteral().ToInt(typeCtx(ctx))
		return int64(val), err != nil, err
	case tp.Hybrid() || dt.Kind() == types.KindString:
		res, err := dt.ToInt64(typeCtx(ctx))
		return res, false, err
	case dt.Kind() == types.KindMysqlBit:
		uintVal, err := dt.GetBinaryLiteral().ToInt(typeCtx(ctx))
		return int64(uintVal), false, err
	}
	return dt.GetInt64(), false, nil
}
// EvalReal returns real representation of Constant.
//
// The datum comes from the bound parameter or deferred expression when the
// constant is lazy, otherwise from c.Value. The result type is resolved once
// per call, because GetType allocates and re-reads the parameter for
// parameter-marker constants. Hybrid types, binary literals and strings go
// through ToFloat64; other datums are read directly as float64.
func (c *Constant) EvalReal(ctx EvalContext, row chunk.Row) (float64, bool, error) {
	dt, lazy, err := c.getLazyDatum(ctx, row)
	if err != nil {
		return 0, false, err
	}
	if !lazy {
		dt = c.Value
	}
	tp := c.GetType(ctx)
	if tp.GetType() == mysql.TypeNull || dt.IsNull() {
		return 0, true, nil
	}
	if tp.Hybrid() || dt.Kind() == types.KindBinaryLiteral || dt.Kind() == types.KindString {
		res, err := dt.ToFloat64(typeCtx(ctx))
		return res, false, err
	}
	return dt.GetFloat64(), false, nil
}
// EvalString returns string representation of Constant.
//
// The datum comes from the bound parameter or deferred expression when the
// constant is lazy, otherwise from c.Value; it is rendered with ToString.
func (c *Constant) EvalString(ctx EvalContext, row chunk.Row) (string, bool, error) {
	dt, lazy, err := c.getLazyDatum(ctx, row)
	switch {
	case err != nil:
		return "", false, err
	case !lazy:
		dt = c.Value
	}
	if c.GetType(ctx).GetType() == mysql.TypeNull || dt.IsNull() {
		return "", true, nil
	}
	s, convErr := dt.ToString()
	return s, false, convErr
}
// EvalDecimal returns decimal representation of Constant.
//
// The datum comes from the bound parameter or deferred expression when the
// constant is lazy, otherwise from c.Value. It is converted with ToDecimal and
// then passed to adjustDecimal, which rounds it up to the declared number of
// fractional digits.
//
// NOTE(review): for a non-lazy decimal constant, ToDecimal may hand back the
// *MyDecimal stored inside c.Value. In that case adjustDecimal rounds the
// shared value in place — confirm whether that mutation is intended.
func (c *Constant) EvalDecimal(ctx EvalContext, row chunk.Row) (*types.MyDecimal, bool, error) {
	dt, lazy, err := c.getLazyDatum(ctx, row)
	switch {
	case err != nil:
		return nil, false, err
	case !lazy:
		dt = c.Value
	}
	if c.GetType(ctx).GetType() == mysql.TypeNull || dt.IsNull() {
		return nil, true, nil
	}
	dec, convErr := dt.ToDecimal(typeCtx(ctx))
	if convErr != nil {
		return nil, false, convErr
	}
	if adjErr := c.adjustDecimal(ctx, dec); adjErr != nil {
		return nil, false, adjErr
	}
	return dec, false, nil
}
// adjustDecimal rounds d in place (half-up) so that it has at least as many
// fractional digits as the constant's current type declares. It is needed
// because the decimal value's precision and frac may be modified during plan
// building. A d that already has enough fractional digits is left untouched.
func (c *Constant) adjustDecimal(ctx EvalContext, d *types.MyDecimal) error {
	_, frac := d.PrecisionAndFrac()
	target := c.GetType(ctx).GetDecimal()
	if frac >= target {
		return nil
	}
	return d.Round(d, target, types.ModeHalfUp)
}
// EvalTime returns DATE/DATETIME/TIMESTAMP representation of Constant.
//
// The datum comes from the bound parameter or deferred expression when the
// constant is lazy, otherwise from c.Value; it is read with GetMysqlTime.
func (c *Constant) EvalTime(ctx EvalContext, row chunk.Row) (val types.Time, isNull bool, err error) {
	dt, lazy, lazyErr := c.getLazyDatum(ctx, row)
	switch {
	case lazyErr != nil:
		return types.ZeroTime, false, lazyErr
	case !lazy:
		dt = c.Value
	}
	if c.GetType(ctx).GetType() == mysql.TypeNull || dt.IsNull() {
		return types.ZeroTime, true, nil
	}
	return dt.GetMysqlTime(), false, nil
}
// EvalDuration returns Duration representation of Constant.
//
// The datum comes from the bound parameter or deferred expression when the
// constant is lazy, otherwise from c.Value; it is read with GetMysqlDuration.
func (c *Constant) EvalDuration(ctx EvalContext, row chunk.Row) (val types.Duration, isNull bool, err error) {
	dt, lazy, lazyErr := c.getLazyDatum(ctx, row)
	switch {
	case lazyErr != nil:
		return types.Duration{}, false, lazyErr
	case !lazy:
		dt = c.Value
	}
	if c.GetType(ctx).GetType() == mysql.TypeNull || dt.IsNull() {
		return types.Duration{}, true, nil
	}
	return dt.GetMysqlDuration(), false, nil
}
// EvalJSON returns JSON representation of Constant.
//
// The datum comes from the bound parameter or deferred expression when the
// constant is lazy, otherwise from c.Value; it is read with GetMysqlJSON.
func (c *Constant) EvalJSON(ctx EvalContext, row chunk.Row) (types.BinaryJSON, bool, error) {
	dt, lazy, err := c.getLazyDatum(ctx, row)
	switch {
	case err != nil:
		return types.BinaryJSON{}, false, err
	case !lazy:
		dt = c.Value
	}
	if c.GetType(ctx).GetType() == mysql.TypeNull || dt.IsNull() {
		return types.BinaryJSON{}, true, nil
	}
	return dt.GetMysqlJSON(), false, nil
}
// EvalVectorFloat32 returns VectorFloat32 representation of Constant.
//
// The datum comes from the bound parameter or deferred expression when the
// constant is lazy, otherwise from c.Value; it is read with GetVectorFloat32.
func (c *Constant) EvalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) {
	dt, lazy, err := c.getLazyDatum(ctx, row)
	switch {
	case err != nil:
		return types.ZeroVectorFloat32, false, err
	case !lazy:
		dt = c.Value
	}
	if c.GetType(ctx).GetType() == mysql.TypeNull || dt.IsNull() {
		return types.ZeroVectorFloat32, true, nil
	}
	return dt.GetVectorFloat32(), false, nil
}
// Equal implements Expression interface.
//
// b must be a *Constant. Both constants are evaluated (b first) and any
// evaluation error makes them unequal. Otherwise the stored Value fields are
// compared with the binary collator.
//
// NOTE(review): the evaluated datums are discarded and c.Value / b.Value are
// compared instead, so two lazy constants compare by their stored Value rather
// than their resolved value — confirm this is intended.
func (c *Constant) Equal(ctx EvalContext, b Expression) bool {
	other, isConst := b.(*Constant)
	if !isConst {
		return false
	}
	_, otherErr := other.Eval(ctx, chunk.Row{})
	_, selfErr := c.Eval(ctx, chunk.Row{})
	if otherErr != nil || selfErr != nil {
		return false
	}
	cmp, cmpErr := c.Value.Compare(typeCtx(ctx), &other.Value, collate.GetBinaryCollator())
	return cmpErr == nil && cmp == 0
}
// IsCorrelated implements Expression interface.
//
// A constant never references an outer query, so it is never correlated.
func (*Constant) IsCorrelated() bool {
	const correlated = false
	return correlated
}
// ConstLevel returns the const level for the expression
//
// Parameter markers and deferred expressions only have a fixed value within a
// given execution context; plain literals are strictly constant.
func (c *Constant) ConstLevel() ConstLevel {
	contextual := c.DeferredExpr != nil || c.ParamMarker != nil
	if !contextual {
		return ConstStrict
	}
	return ConstOnlyInContext
}
// Decorrelate implements Expression interface.
//
// A constant has nothing to decorrelate, so the receiver is returned unchanged.
func (c *Constant) Decorrelate(*Schema) Expression {
	return c
}
// HashCode implements Expression interface.
// It returns the non-canonical hash code computed by getHashCode.
func (c *Constant) HashCode() []byte {
	const canonical = false
	return c.getHashCode(canonical)
}

// CanonicalHashCode implements Expression interface.
// It returns the canonical hash code computed by getHashCode.
func (c *Constant) CanonicalHashCode() []byte {
	const canonical = true
	return c.getHashCode(canonical)
}
// Hash64 implements HashEquals.<0th> interface.
//
// The hash always covers RetType (with a nil/non-nil marker) and the collation
// info. It then covers exactly one of the following: the deferred expression,
// the parameter marker's order (prefixed by parameterFlag), or the literal
// Value (prefixed by constantFlag).
func (c *Constant) Hash64(h base.Hasher) {
	if c.RetType != nil {
		h.HashByte(base.NotNilFlag)
		c.RetType.Hash64(h)
	} else {
		h.HashByte(base.NilFlag)
	}
	c.collationInfo.Hash64(h)

	switch {
	case c.DeferredExpr != nil:
		c.DeferredExpr.Hash64(h)
	case c.ParamMarker != nil:
		h.HashByte(parameterFlag)
		h.HashInt64(int64(c.ParamMarker.order))
	default:
		h.HashByte(constantFlag)
		c.Value.Hash64(h)
	}
}
// Equals implements HashEquals.<1st> interface.
//
// Two constants are equal when they agree on RetType, collation info, deferred
// expression, parameter marker order and Value. Nil pointers compare equal only
// to each other.
func (c *Constant) Equals(other any) bool {
	c2, isConst := other.(*Constant)
	if !isConst {
		return false
	}
	if c == nil || c2 == nil {
		return c == c2
	}

	sameRetType := (c.RetType == nil && c2.RetType == nil) ||
		(c.RetType != nil && c2.RetType != nil && c.RetType.Equals(c2.RetType))
	if !sameRetType || !c.collationInfo.Equals(c2.collationInfo) {
		return false
	}

	sameDeferred := (c.DeferredExpr == nil && c2.DeferredExpr == nil) ||
		(c.DeferredExpr != nil && c2.DeferredExpr != nil && c.DeferredExpr.Equals(c2.DeferredExpr))
	if !sameDeferred {
		return false
	}

	sameParam := (c.ParamMarker == nil && c2.ParamMarker == nil) ||
		(c.ParamMarker != nil && c2.ParamMarker != nil && c.ParamMarker.order == c2.ParamMarker.order)
	if !sameParam {
		return false
	}
	return c.Value.Equals(c2.Value)
}
// getHashCode computes the hash code of the constant, in canonical form when
// canonical is true.
//
// A constant backed by a deferred expression delegates to that expression on
// every call and does not store the result in c.hashcode. The expression's
// canonical and non-canonical hash codes can differ, and with only one cache
// slot the form requested first would be returned for both. Parameter markers
// and plain values hash the same way in both forms, so their result is
// cached in c.hashcode.
func (c *Constant) getHashCode(canonical bool) []byte {
	if c.DeferredExpr != nil {
		if canonical {
			return c.DeferredExpr.CanonicalHashCode()
		}
		return c.DeferredExpr.HashCode()
	}
	if len(c.hashcode) > 0 {
		return c.hashcode
	}
	if c.ParamMarker != nil {
		c.hashcode = append(c.hashcode, parameterFlag)
		c.hashcode = codec.EncodeInt(c.hashcode, int64(c.ParamMarker.order))
		return c.hashcode
	}
	c.hashcode = append(c.hashcode, constantFlag)
	c.hashcode = codec.HashCode(c.hashcode, c.Value)
	return c.hashcode
}
// ResolveIndices implements Expression interface.
// A constant references no columns, so it is returned as-is.
func (c *Constant) ResolveIndices(*Schema) (Expression, error) {
	return c, nil
}

// resolveIndices is a no-op for constants.
func (*Constant) resolveIndices(*Schema) error {
	return nil
}

// ResolveIndicesByVirtualExpr implements Expression interface.
// A constant references no columns, so resolution always succeeds on the receiver itself.
func (c *Constant) ResolveIndicesByVirtualExpr(EvalContext, *Schema) (Expression, bool) {
	return c, true
}

// resolveIndicesByVirtualExpr always succeeds for constants.
func (*Constant) resolveIndicesByVirtualExpr(EvalContext, *Schema) bool {
	return true
}

// RemapColumn remaps columns with provided mapping and returns new expression
// A constant has no columns to remap, so the receiver is returned unchanged.
func (c *Constant) RemapColumn(map[int64]*Column) (Expression, error) {
	return c, nil
}
// Vectorized returns if this expression supports vectorized evaluation.
//
// Stored literals and parameter markers are always vectorizable; a deferred
// expression answers for itself.
func (c *Constant) Vectorized() bool {
	if c.DeferredExpr == nil {
		return true
	}
	return c.DeferredExpr.Vectorized()
}
// Coercibility returns the coercibility value which is used to check collations.
//
// The value is derived lazily on first use via deriveCoercibilityForConstant
// and stored in the embedded collationInfo.
func (c *Constant) Coercibility() Coercibility {
	if c.HasCoercibility() {
		return c.collationInfo.Coercibility()
	}
	derived := deriveCoercibilityForConstant(c)
	c.SetCoercibility(derived)
	return c.collationInfo.Coercibility()
}
// emptyConstantSize is the in-memory size of a zero Constant struct, excluding
// anything it points to.
const emptyConstantSize = int64(unsafe.Sizeof(Constant{}))

// MemoryUsage return the memory usage of Constant
//
// It counts the struct itself, the datum's payload, the capacity of the cached
// hash code and, when present, the return type. The deferred expression and
// parameter marker are not counted. A nil receiver reports 0.
func (c *Constant) MemoryUsage() (sum int64) {
	if c == nil {
		return 0
	}
	sum = emptyConstantSize
	sum += c.Value.MemUsage()
	sum += int64(cap(c.hashcode))
	if c.RetType != nil {
		sum += c.RetType.MemoryUsage()
	}
	return sum
}