tidb/pkg/planner/core/expression_rewriter.go
// Copyright 2016 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package core
import (
"context"
"fmt"
"slices"
"strconv"
"strings"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/expression/aggregation"
"github.com/pingcap/tidb/pkg/expression/exprctx"
"github.com/pingcap/tidb/pkg/expression/expropt"
"github.com/pingcap/tidb/pkg/infoschema"
"github.com/pingcap/tidb/pkg/meta/model"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/charset"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/opcode"
"github.com/pingcap/tidb/pkg/planner/core/base"
"github.com/pingcap/tidb/pkg/planner/core/operator/logicalop"
"github.com/pingcap/tidb/pkg/planner/core/operator/physicalop"
"github.com/pingcap/tidb/pkg/planner/core/rule"
"github.com/pingcap/tidb/pkg/planner/util/coreusage"
"github.com/pingcap/tidb/pkg/sessionctx/vardef"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/table"
"github.com/pingcap/tidb/pkg/types"
driver "github.com/pingcap/tidb/pkg/types/parser_driver"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/collate"
"github.com/pingcap/tidb/pkg/util/dbterror/plannererrors"
"github.com/pingcap/tidb/pkg/util/hint"
"github.com/pingcap/tidb/pkg/util/intest"
sem "github.com/pingcap/tidb/pkg/util/sem/compat"
"github.com/pingcap/tidb/pkg/util/stringutil"
)
// EvalSubqueryFirstRow evaluates an uncorrelated subquery once and returns its first row.
var EvalSubqueryFirstRow func(ctx context.Context, p base.PhysicalPlan, is infoschema.InfoSchema, sctx base.PlanContext) (row []types.Datum, err error)
// evalAstExprWithPlanCtx evaluates an ast expression with the plan context.
// Unlike expression.EvalSimpleAst, it uses the planner context and can build special expressions
// such as subqueries, window functions, etc.
func evalAstExprWithPlanCtx(sctx base.PlanContext, expr ast.ExprNode) (types.Datum, error) {
if val, ok := expr.(*driver.ValueExpr); ok {
return val.Datum, nil
}
newExpr, err := rewriteAstExprWithPlanCtx(sctx, expr, nil, nil, false)
if err != nil {
return types.Datum{}, err
}
return newExpr.Eval(sctx.GetExprCtx().GetEvalCtx(), chunk.Row{})
}
// evalAstExpr evaluates ast expression directly.
func evalAstExpr(ctx expression.BuildContext, expr ast.ExprNode) (types.Datum, error) {
if val, ok := expr.(*driver.ValueExpr); ok {
return val.Datum, nil
}
newExpr, err := buildSimpleExpr(ctx, expr)
if err != nil {
return types.Datum{}, err
}
return newExpr.Eval(ctx.GetEvalCtx(), chunk.Row{})
}
// rewriteAstExprWithPlanCtx rewrites an ast expression directly.
// Unlike expression.BuildSimpleExpr, it uses the planner context and can build special expressions
// such as subqueries, window functions, etc.
func rewriteAstExprWithPlanCtx(sctx base.PlanContext, expr ast.ExprNode, schema *expression.Schema, names types.NameSlice, allowCastArray bool) (expression.Expression, error) {
var is infoschema.InfoSchema
// In tests, it may be nil.
if s, ok := sctx.GetInfoSchema().(infoschema.InfoSchema); ok {
is = s
}
b, savedBlockNames := NewPlanBuilder().Init(sctx, is, hint.NewQBHintHandler(nil))
b.allowBuildCastArray = allowCastArray
fakePlan := logicalop.LogicalTableDual{}.Init(sctx, 0)
if schema != nil {
fakePlan.SetSchema(schema)
fakePlan.SetOutputNames(names)
}
b.curClause = expressionClause
newExpr, _, err := b.rewrite(context.TODO(), expr, fakePlan, nil, true)
if err != nil {
return nil, err
}
sctx.GetSessionVars().PlannerSelectBlockAsName.Store(&savedBlockNames)
return newExpr, nil
}
func buildSimpleExpr(ctx expression.BuildContext, node ast.ExprNode, opts ...expression.BuildOption) (expression.Expression, error) {
intest.AssertNotNil(node)
if node == nil {
// This should never happen. Return error to make it easy to debug in case we have some unexpected bugs.
return nil, errors.New("expression node should be present")
}
var options expression.BuildOptions
for _, opt := range opts {
opt(&options)
}
if options.InputSchema == nil && len(options.InputNames) > 0 {
return nil, errors.New("InputSchema and InputNames should be specified at the same time")
}
if options.InputSchema != nil && len(options.InputSchema.Columns) != len(options.InputNames) {
return nil, errors.New("InputSchema and InputNames should be the same length")
}
// assert all input db names are the same if specified
intest.AssertFunc(func() bool {
if len(options.InputNames) == 0 {
return true
}
dbName := options.InputNames[0].DBName
if options.SourceTableDB.L != "" {
intest.Assert(dbName.L == options.SourceTableDB.L)
}
for _, name := range options.InputNames {
intest.Assert(name.DBName.L == dbName.L)
}
return true
})
rewriter := &expressionRewriter{
ctx: context.TODO(),
sctx: ctx,
schema: options.InputSchema,
names: options.InputNames,
sourceTable: options.SourceTable,
allowBuildCastArray: options.AllowCastArray,
asScalar: true,
}
if tbl := options.SourceTable; tbl != nil && rewriter.schema == nil {
cols, names, err := expression.ColumnInfos2ColumnsAndNames(ctx, options.SourceTableDB, tbl.Name, tbl.Cols(), tbl)
if err != nil {
return nil, err
}
intest.Assert(len(cols) == len(names))
rewriter.schema = expression.NewSchema(cols...)
rewriter.names = names
}
if rewriter.schema == nil {
rewriter.schema = expression.NewSchema()
}
expr, _, err := rewriteExprNode(rewriter, node, rewriter.asScalar)
if err != nil {
return nil, err
}
if ft := options.TargetFieldType; ft != nil {
expr = expression.BuildCastFunction(ctx, expr, ft)
}
return expr, err
}
func (b *PlanBuilder) rewriteInsertOnDuplicateUpdate(ctx context.Context, exprNode ast.ExprNode, mockPlan base.LogicalPlan, insertPlan *physicalop.Insert) (expression.Expression, error) {
b.rewriterCounter++
defer func() { b.rewriterCounter-- }()
b.curClause = fieldList
rewriter := b.getExpressionRewriter(ctx, mockPlan)
// The rewriter may have been obtained from "b.rewriterPool". A non-nil
// "rewriter.err" means some previous procedure failed to handle that error,
// so we take one more chance to behave correctly by handling the missed
// error here.
if rewriter.err != nil {
return nil, rewriter.err
}
rewriter.planCtx.insertPlan = insertPlan
rewriter.asScalar = true
rewriter.allowBuildCastArray = b.allowBuildCastArray
expr, _, err := rewriteExprNode(rewriter, exprNode, true)
return expr, err
}
// rewrite rewrites an ast expression to an expression.Expression.
// aggMapper maps an ast.AggregateFuncExpr to its column offset in p's output schema.
// asScalar indicates whether this expression must be treated as a scalar expression.
// The function returns the result expression and a new plan that may contain an apply or a semi-join.
func (b *PlanBuilder) rewrite(ctx context.Context, exprNode ast.ExprNode, p base.LogicalPlan, aggMapper map[*ast.AggregateFuncExpr]int, asScalar bool) (expression.Expression, base.LogicalPlan, error) {
expr, resultPlan, err := b.rewriteWithPreprocess(ctx, exprNode, p, aggMapper, nil, asScalar, nil)
return expr, resultPlan, err
}
// rewriteWithPreprocess is for handling the situation that we need to adjust the input ast tree
// before really using its node in `expressionRewriter.Leave`. In that case, we first call
// er.preprocess(expr), which returns a new expr. Then we use the new expr in `Leave`.
func (b *PlanBuilder) rewriteWithPreprocess(
ctx context.Context,
exprNode ast.ExprNode,
p base.LogicalPlan, aggMapper map[*ast.AggregateFuncExpr]int,
windowMapper map[*ast.WindowFuncExpr]int,
asScalar bool,
preprocess func(ast.Node) ast.Node,
) (expression.Expression, base.LogicalPlan, error) {
b.rewriterCounter++
defer func() { b.rewriterCounter-- }()
rewriter := b.getExpressionRewriter(ctx, p)
// The rewriter may have been obtained from "b.rewriterPool". A non-nil
// "rewriter.err" means some previous procedure failed to handle that error,
// so we take one more chance to behave correctly by handling the missed
// error here.
if rewriter.err != nil {
return nil, nil, rewriter.err
}
rewriter.planCtx.aggrMap = aggMapper
rewriter.planCtx.windowMap = windowMapper
rewriter.asScalar = asScalar
rewriter.allowBuildCastArray = b.allowBuildCastArray
rewriter.preprocess = preprocess
expr, resultPlan, err := rewriteExprNode(rewriter, exprNode, asScalar)
return expr, resultPlan, err
}
func (b *PlanBuilder) getExpressionRewriter(ctx context.Context, p base.LogicalPlan) (rewriter *expressionRewriter) {
defer func() {
if p != nil {
if join, ok := p.(*logicalop.LogicalJoin); ok && join.FullSchema != nil {
rewriter.schema = join.FullSchema
rewriter.names = join.FullNames
} else {
rewriter.schema = p.Schema()
rewriter.names = p.OutputNames()
}
}
}()
if len(b.rewriterPool) < b.rewriterCounter {
rewriter = &expressionRewriter{
sctx: b.ctx.GetExprCtx(), ctx: ctx,
planCtx: &exprRewriterPlanCtx{plan: p, builder: b, curClause: b.curClause, rollExpand: b.currentBlockExpand},
}
b.rewriterPool = append(b.rewriterPool, rewriter)
return
}
rewriter = b.rewriterPool[b.rewriterCounter-1]
rewriter.asScalar = false
rewriter.preprocess = nil
rewriter.disableFoldCounter = 0
rewriter.tryFoldCounter = 0
rewriter.ctxStack = rewriter.ctxStack[:0]
rewriter.ctxNameStk = rewriter.ctxNameStk[:0]
rewriter.ctx = ctx
rewriter.err = nil
rewriter.planCtx.plan = p
rewriter.planCtx.curClause = b.curClause
rewriter.planCtx.aggrMap = nil
rewriter.planCtx.insertPlan = nil
rewriter.planCtx.rollExpand = b.currentBlockExpand
return
}
func rewriteExprNode(rewriter *expressionRewriter, exprNode ast.ExprNode, asScalar bool) (expression.Expression, base.LogicalPlan, error) {
planCtx := rewriter.planCtx
// sourceTable is only used to build a simple expression against a single table;
// when planCtx is present, sourceTable should be nil.
intest.Assert(planCtx == nil || rewriter.sourceTable == nil)
if planCtx != nil && planCtx.plan != nil {
curColLen := planCtx.plan.Schema().Len()
defer func() {
names := planCtx.plan.OutputNames().Shallow()[:curColLen]
for i := curColLen; i < planCtx.plan.Schema().Len(); i++ {
names = append(names, types.EmptyName)
}
// After rewriting finishes, only the old columns are visible.
// e.g. select * from t where t.a in (select t1.a from t1);
// The output columns before we enter the subquery are the columns from t.
// But when we leave the subquery `t.a in (select t1.a from t1)`, we get an Apply operator
// and the output columns become [t.*, t1.*]. t1.* is used only inside the subquery; if there is
// another filter that is also a subquery involving t1, name resolution would fail if we still
// exposed the columns from the previous subquery.
// So here we just reset the extra names to empty to avoid this situation.
// TODO: implement ScalarSubQuery and resolve it during optimizing. In the building phase, we will not change the plan's structure.
planCtx.plan.SetOutputNames(names)
}()
}
exprNode.Accept(rewriter)
if rewriter.err != nil {
return nil, nil, errors.Trace(rewriter.err)
}
var plan base.LogicalPlan
if planCtx != nil {
plan = planCtx.plan
}
if !asScalar && len(rewriter.ctxStack) == 0 {
return nil, plan, nil
}
if len(rewriter.ctxStack) != 1 {
return nil, nil, errors.Errorf("context len %v is invalid", len(rewriter.ctxStack))
}
rewriter.err = expression.CheckArgsNotMultiColumnRow(rewriter.ctxStack[0])
if rewriter.err != nil {
return nil, nil, errors.Trace(rewriter.err)
}
return rewriter.ctxStack[0], plan, nil
}
type exprRewriterPlanCtx struct {
plan base.LogicalPlan
builder *PlanBuilder
// curClause tracks which part of the query is being processed
curClause clauseCode
aggrMap map[*ast.AggregateFuncExpr]int
windowMap map[*ast.WindowFuncExpr]int
// insertPlan is only used to rewrite the expressions inside the assignment
// of the "INSERT" statement.
insertPlan *physicalop.Insert
rollExpand *logicalop.LogicalExpand
}
type expressionRewriter struct {
ctxStack []expression.Expression
ctxNameStk []*types.FieldName
schema *expression.Schema
names []*types.FieldName
err error
sctx expression.BuildContext
ctx context.Context
// asScalar indicates that the return value must be a scalar value.
// NOTE: This value can be changed during expression rewriting.
asScalar bool
// allowBuildCastArray indicates whether allow cast(... as ... array).
allowBuildCastArray bool
// sourceTable is only used to build a simple expression with all columns from a single table
sourceTable *model.TableInfo
// preprocess is called for every ast.Node in Leave.
preprocess func(ast.Node) ast.Node
// disableFoldCounter controls the fold-disabled scope. If > 0, the rewriter will NOT do constant folding.
// Typically, while visiting the AST, entering such a scope (disable) increments the counter and
// leaving the scope (enable again) decrements it.
// NOTE: This value can be changed during expression rewriting.
disableFoldCounter int
tryFoldCounter int
planCtx *exprRewriterPlanCtx
}
func (er *expressionRewriter) ctxStackLen() int {
return len(er.ctxStack)
}
func (er *expressionRewriter) ctxStackPop(num int) {
l := er.ctxStackLen()
er.ctxStack = er.ctxStack[:l-num]
er.ctxNameStk = er.ctxNameStk[:l-num]
}
func (er *expressionRewriter) ctxStackAppend(col expression.Expression, name *types.FieldName) {
er.ctxStack = append(er.ctxStack, col)
er.ctxNameStk = append(er.ctxNameStk, name)
}
// constructBinaryOpFunction converts binary operator functions.
/*
The algorithm is as follows:
1. If the length of the two sides of the expression is 1, return l op r directly.
2. If the lengths of the two sides of the expression are not equal, return an error.
3. If the operator is EQ or NullEQ, convert (a0,a1,a2) op (b0,b1,b2) to (a0 op b0) and (a1 op b1) and (a2 op b2);
   for NE, the per-column comparisons are combined with OR instead.
4. If the operator is not EQ, NE, or NullEQ,
   convert (a0,a1,a2) op (b0,b1,b2) to (a0 > b0) or (a0 = b0 and a1 > b1) or (a0 = b0 and a1 = b1 and a2 op b2).
   In particular, when op is GE or LE, the prefix elements are compared with > or <:
   (a0,a1,a2) >= (b0,b1,b2) becomes (a0 > b0) or (a0 = b0 and a1 > b1) or (a0 = b0 and a1 = b1 and a2 >= b2).
   The only difference between >= and > is that >= additionally includes the case (a0,a1,a2) = (b0,b1,b2).
*/
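// For example, a worked sketch of rules 3 and 4 above:
//
//	(a, b) = (x, y)        =>  (a = x) AND (b = y)
//	(a, b) != (x, y)       =>  (a != x) OR (b != y)
//	(a, b, c) >= (x, y, z) =>  (a > x)
//	                           OR (a = x AND b > y)
//	                           OR (a = x AND b = y AND c >= z)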
func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression, r expression.Expression, op string) (expression.Expression, error) {
lLen, rLen := expression.GetRowLen(l), expression.GetRowLen(r)
if lLen == 1 && rLen == 1 {
return er.newFunction(op, types.NewFieldType(mysql.TypeTiny), l, r)
} else if rLen != lLen {
return nil, expression.ErrOperandColumns.GenWithStackByArgs(lLen)
}
switch op {
case ast.EQ, ast.NE, ast.NullEQ:
funcs := make([]expression.Expression, lLen)
for i := range lLen {
var err error
funcs[i], err = er.constructBinaryOpFunction(expression.GetFuncArg(l, i), expression.GetFuncArg(r, i), op)
if err != nil {
return nil, err
}
}
if op == ast.NE {
return expression.ComposeDNFCondition(er.sctx, funcs...), nil
}
return expression.ComposeCNFCondition(er.sctx, funcs...), nil
default:
/*
The algorithm is as follows:
1. For each position i of the left-hand columns, construct a CNF:
   1.1 for every j < i, add the prefix equality l[j] = r[j];
   1.2 build the comparison for position i as l[i] op r[i];
   1.3 combine the predicates from 1.1 and 1.2 with AND.
2. Combine the per-position CNFs with OR.
*/
resultDNFList := make([]expression.Expression, 0, lLen)
// Step 1
for i := range lLen {
exprList := make([]expression.Expression, 0, i+1)
// Step 1.1: build the prefix equality conditions.
// (l[0], ... , l[i-1], ...) op (r[0], ... , r[i-1], ...) should be converted to
// l[0] = r[0] and l[1] = r[1] and ... and l[i-1] = r[i-1]
for j := range i {
jExpr, err := er.constructBinaryOpFunction(expression.GetFuncArg(l, j), expression.GetFuncArg(r, j), ast.EQ)
if err != nil {
return nil, err
}
exprList = append(exprList, jExpr)
}
// In particular, when op is GE or LE, the prefix elements are compared with > or <.
degeneratedOp := op
if i < lLen-1 {
switch op {
case ast.GE:
degeneratedOp = ast.GT
case ast.LE:
degeneratedOp = ast.LT
}
}
// Step 1.2
currentIndexExpr, err := er.constructBinaryOpFunction(expression.GetFuncArg(l, i), expression.GetFuncArg(r, i), degeneratedOp)
if err != nil {
return nil, err
}
exprList = append(exprList, currentIndexExpr)
// Step 1.3
currentExpr := expression.ComposeCNFCondition(er.sctx, exprList...)
resultDNFList = append(resultDNFList, currentExpr)
}
// Step 2
return expression.ComposeDNFCondition(er.sctx, resultDNFList...), nil
}
}
// buildSubquery translates the subquery ast to plan.
// Subquery related hints are returned through hintFlags. Please see comments around HintFlagSemiJoinRewrite and PlanBuilder.subQueryHintFlags for details.
func (er *expressionRewriter) buildSubquery(ctx context.Context, planCtx *exprRewriterPlanCtx, subq *ast.SubqueryExpr, subqueryCtx subQueryCtx) (np base.LogicalPlan, hintFlags uint64, err error) {
intest.AssertNotNil(planCtx)
b := planCtx.builder
if er.schema != nil {
outerSchema := er.schema.Clone()
b.outerSchemas = append(b.outerSchemas, outerSchema)
b.outerNames = append(b.outerNames, er.names)
b.outerBlockExpand = append(b.outerBlockExpand, b.currentBlockExpand)
// Set it to nil; otherwise the inner query block would use the outer Expand metadata to rewrite expressions.
b.currentBlockExpand = nil
defer func() {
b.outerSchemas = b.outerSchemas[0 : len(b.outerSchemas)-1]
b.outerNames = b.outerNames[0 : len(b.outerNames)-1]
b.currentBlockExpand = b.outerBlockExpand[len(b.outerBlockExpand)-1]
b.outerBlockExpand = b.outerBlockExpand[0 : len(b.outerBlockExpand)-1]
}()
}
// Store the old values before we enter the subquery and reset them to their default values.
oldSubQCtx := b.subQueryCtx
b.subQueryCtx = subqueryCtx
oldHintFlags := b.subQueryHintFlags
b.subQueryHintFlags = 0
outerWindowSpecs := b.windowSpecs
defer func() {
b.windowSpecs = outerWindowSpecs
b.subQueryCtx = oldSubQCtx
b.subQueryHintFlags = oldHintFlags
}()
np, err = b.buildResultSetNode(ctx, subq.Query, false)
if err != nil {
return nil, 0, err
}
hintFlags = b.subQueryHintFlags
// Pop the handle map generated by the subquery.
b.handleHelper.popMap()
return np, hintFlags, nil
}
func (er *expressionRewriter) requirePlanCtx(inNode ast.Node, detail string) (ctx *exprRewriterPlanCtx, err error) {
if ctx = er.planCtx; ctx == nil {
if detail != "" {
detail = ", " + detail
}
err = errors.Errorf("planCtx is required when rewriting node: '%T'%s", inNode, detail)
}
return
}
// Enter implements Visitor interface.
func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
enterWithPlanCtx := func(fn func(*exprRewriterPlanCtx) (ast.Node, bool)) (ast.Node, bool) {
planCtx, err := er.requirePlanCtx(inNode, "")
if err != nil {
er.err = err
return inNode, true
}
return fn(planCtx)
}
switch v := inNode.(type) {
case *ast.AggregateFuncExpr:
return enterWithPlanCtx(func(planCtx *exprRewriterPlanCtx) (ast.Node, bool) {
index, ok := -1, false
if planCtx.aggrMap != nil {
index, ok = planCtx.aggrMap[v]
}
if ok {
// index < 0 indicates this is a correlated aggregate belonging to outer query,
// for which a correlated column will be created later, so we append a null constant
// as a temporary result expression.
if index < 0 {
er.ctxStackAppend(expression.NewNull(), types.EmptyName)
} else {
// index >= 0 indicates this is a regular aggregate column
er.ctxStackAppend(er.schema.Columns[index], er.names[index])
}
return inNode, true
}
// replace correlated aggregate in sub-query with its corresponding correlated column
if col, ok := planCtx.builder.correlatedAggMapper[v]; ok {
er.ctxStackAppend(col, types.EmptyName)
return inNode, true
}
er.err = plannererrors.ErrInvalidGroupFuncUse
return inNode, true
})
case *ast.ColumnNameExpr:
if planCtx := er.planCtx; planCtx != nil {
if index, ok := planCtx.builder.colMapper[v]; ok {
er.ctxStackAppend(er.schema.Columns[index], er.names[index])
return inNode, true
}
}
case *ast.CompareSubqueryExpr:
return enterWithPlanCtx(func(planCtx *exprRewriterPlanCtx) (ast.Node, bool) {
return er.handleCompareSubquery(er.ctx, planCtx, v)
})
case *ast.ExistsSubqueryExpr:
return enterWithPlanCtx(func(planCtx *exprRewriterPlanCtx) (ast.Node, bool) {
return er.handleExistSubquery(er.ctx, planCtx, v)
})
case *ast.PatternInExpr:
if v.Sel != nil {
return enterWithPlanCtx(func(planCtx *exprRewriterPlanCtx) (ast.Node, bool) {
return er.handleInSubquery(er.ctx, planCtx, v)
})
}
if len(v.List) != 1 {
break
}
// For 10 in ((select * from t)), the parser won't set v.Sel.
// So we must process this case here.
x := v.List[0]
for {
switch y := x.(type) {
case *ast.SubqueryExpr:
v.Sel = y
return enterWithPlanCtx(func(planCtx *exprRewriterPlanCtx) (ast.Node, bool) {
return er.handleInSubquery(er.ctx, planCtx, v)
})
case *ast.ParenthesesExpr:
x = y.Expr
default:
// Expect its left and right child to be a scalar value.
er.asScalar = true
return inNode, false
}
}
case *ast.SubqueryExpr:
return enterWithPlanCtx(func(planCtx *exprRewriterPlanCtx) (ast.Node, bool) {
return er.handleScalarSubquery(er.ctx, planCtx, v)
})
case *ast.ParenthesesExpr:
case *ast.ValuesExpr:
return enterWithPlanCtx(func(planCtx *exprRewriterPlanCtx) (ast.Node, bool) {
schema, names := er.schema, er.names
// NOTE: "er.insertPlan != nil" means that we are rewriting the
// expressions inside the assignment of "INSERT" statement. we have to
// use the "tableSchema" of that "insertPlan".
if planCtx.insertPlan != nil {
schema = planCtx.insertPlan.TableSchema
names = planCtx.insertPlan.TableColNames
}
idx, err := expression.FindFieldName(names, v.Column.Name)
if err != nil {
er.err = err
return inNode, false
}
if idx < 0 {
er.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs(v.Column.Name.OrigColName(), "field list")
return inNode, false
}
col := schema.Columns[idx]
er.ctxStackAppend(expression.NewValuesFunc(er.sctx, col.Index, col.RetType), types.EmptyName)
return inNode, true
})
case *ast.WindowFuncExpr:
return enterWithPlanCtx(func(planCtx *exprRewriterPlanCtx) (ast.Node, bool) {
intest.AssertNotNil(planCtx)
index, ok := -1, false
if planCtx.windowMap != nil {
index, ok = planCtx.windowMap[v]
}
if !ok {
er.err = plannererrors.ErrWindowInvalidWindowFuncUse.GenWithStackByArgs(strings.ToLower(v.Name))
return inNode, true
}
er.ctxStackAppend(er.schema.Columns[index], er.names[index])
return inNode, true
})
case *ast.FuncCallExpr:
er.asScalar = true
if _, ok := expression.DisableFoldFunctions[v.FnName.L]; ok {
er.disableFoldCounter++
}
if _, ok := expression.TryFoldFunctions[v.FnName.L]; ok {
er.tryFoldCounter++
}
case *ast.CaseExpr:
er.asScalar = true
if _, ok := expression.DisableFoldFunctions["case"]; ok {
er.disableFoldCounter++
}
if _, ok := expression.TryFoldFunctions["case"]; ok {
er.tryFoldCounter++
}
case *ast.BinaryOperationExpr:
er.asScalar = true
if v.Op == opcode.LogicAnd || v.Op == opcode.LogicOr {
er.tryFoldCounter++
}
case *ast.SetCollationExpr:
// Do nothing
default:
er.asScalar = true
}
return inNode, false
}
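// buildSemiApplyFromEqualSubq builds a semi apply for `expr = ANY (subq)` and
// `expr != ALL (subq)`. A summary sketch of the logic below: when either side of
// the equality may be NULL, the inner column is marked with InOperand (and the
// outer side via SetExprColumnInOperand) so the comparison keeps null-aware
// semantics instead of being treated as a plain column-equal condition.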
func (er *expressionRewriter) buildSemiApplyFromEqualSubq(np base.LogicalPlan, planCtx *exprRewriterPlanCtx, l, r expression.Expression, not, markNoDecorrelate bool) {
intest.AssertNotNil(planCtx)
if er.asScalar || not {
if expression.GetRowLen(r) == 1 {
rCol := r.(*expression.Column)
// If both input columns of `!= all / = any` expression are not null, we can treat the expression
// as normal column equal condition.
if !expression.ExprNotNull(er.sctx.GetEvalCtx(), l) || !expression.ExprNotNull(er.sctx.GetEvalCtx(), rCol) {
rColCopy := *rCol
rColCopy.InOperand = true
r = &rColCopy
l = expression.SetExprColumnInOperand(l)
}
} else {
rowFunc := r.(*expression.ScalarFunction)
rargs := rowFunc.GetArgs()
args := make([]expression.Expression, 0, len(rargs))
modified := false
for i, rarg := range rargs {
larg := expression.GetFuncArg(l, i)
if !expression.ExprNotNull(er.sctx.GetEvalCtx(), larg) || !expression.ExprNotNull(er.sctx.GetEvalCtx(), rarg) {
rCol := rarg.(*expression.Column)
rColCopy := *rCol
rColCopy.InOperand = true
rarg = &rColCopy
modified = true
}
args = append(args, rarg)
}
if modified {
r, er.err = er.newFunction(ast.RowFunc, args[0].GetType(er.sctx.GetEvalCtx()), args...)
if er.err != nil {
return
}
l = expression.SetExprColumnInOperand(l)
}
}
}
var condition expression.Expression
condition, er.err = er.constructBinaryOpFunction(l, r, ast.EQ)
if er.err != nil {
return
}
planCtx.plan, er.err = planCtx.builder.buildSemiApply(planCtx.plan, np, []expression.Expression{condition}, er.asScalar, not, false, markNoDecorrelate)
}
func (er *expressionRewriter) handleCompareSubquery(ctx context.Context, planCtx *exprRewriterPlanCtx, v *ast.CompareSubqueryExpr) (ast.Node, bool) {
intest.AssertNotNil(planCtx)
b := planCtx.builder
ci := b.prepareCTECheckForSubQuery()
defer resetCTECheckForSubQuery(ci)
v.L.Accept(er)
if er.err != nil {
return v, true
}
lexpr := er.ctxStack[len(er.ctxStack)-1]
subq, ok := v.R.(*ast.SubqueryExpr)
if !ok {
er.err = errors.Errorf("Unknown compare type %T", v.R)
return v, true
}
np, hintFlags, err := er.buildSubquery(ctx, planCtx, subq, handlingCompareSubquery)
if err != nil {
er.err = err
return v, true
}
corCols := coreusage.ExtractCorColumnsBySchema4LogicalPlan(np, planCtx.plan.Schema())
noDecorrelate := isNoDecorrelate(planCtx, corCols, hintFlags, handlingCompareSubquery)
// Only (a,b,c) = any (...) and (a,b,c) != all (...) can use row expression.
canMultiCol := (!v.All && v.Op == opcode.EQ) || (v.All && v.Op == opcode.NE)
if !canMultiCol && (expression.GetRowLen(lexpr) != 1 || np.Schema().Len() != 1) {
er.err = expression.ErrOperandColumns.GenWithStackByArgs(1)
return v, true
}
lLen := expression.GetRowLen(lexpr)
if lLen != np.Schema().Len() {
er.err = expression.ErrOperandColumns.GenWithStackByArgs(lLen)
return v, true
}
var rexpr expression.Expression
if np.Schema().Len() == 1 {
rexpr = np.Schema().Columns[0]
} else {
args := make([]expression.Expression, 0, np.Schema().Len())
for _, col := range np.Schema().Columns {
args = append(args, col)
}
rexpr, er.err = er.newFunction(ast.RowFunc, args[0].GetType(er.sctx.GetEvalCtx()), args...)
if er.err != nil {
return v, true
}
}
// lexpr cannot be compared with rexpr under a different collation.
opString := new(strings.Builder)
v.Op.Format(opString)
_, er.err = expression.CheckAndDeriveCollationFromExprs(er.sctx, opString.String(), types.ETInt, lexpr, rexpr)
if er.err != nil {
return v, true
}
switch v.Op {
// Only EQ, NE and NullEQ can be composed with and.
case opcode.EQ, opcode.NE, opcode.NullEQ:
if v.Op == opcode.EQ {
if v.All {
er.handleEQAll(planCtx, lexpr, rexpr, np, noDecorrelate)
} else {
// `a = any(subq)` will be rewritten as `a in (subq)`.
er.asScalar = true
er.buildSemiApplyFromEqualSubq(np, planCtx, lexpr, rexpr, false, noDecorrelate)
if er.err != nil {
return v, true
}
}
} else if v.Op == opcode.NE {
if v.All {
// `a != all(subq)` will be rewritten as `a not in (subq)`.
er.asScalar = true
er.buildSemiApplyFromEqualSubq(np, planCtx, lexpr, rexpr, true, noDecorrelate)
if er.err != nil {
return v, true
}
} else {
er.handleNEAny(planCtx, lexpr, rexpr, np, noDecorrelate)
}
} else {
// TODO: Support this in future.
er.err = errors.New("We don't support <=> all or <=> any now")
return v, true
}
default:
// For < all or > any, the aggregate function should be min.
useMin := ((v.Op == opcode.LT || v.Op == opcode.LE) && v.All) || ((v.Op == opcode.GT || v.Op == opcode.GE) && !v.All)
er.handleOtherComparableSubq(planCtx, lexpr, rexpr, np, useMin, v.Op.String(), v.All, noDecorrelate)
}
if er.asScalar {
// The parent expression only uses the last column in the schema, which represents whether the condition is matched.
er.ctxStack[len(er.ctxStack)-1] = planCtx.plan.Schema().Columns[planCtx.plan.Schema().Len()-1]
er.ctxNameStk[len(er.ctxNameStk)-1] = planCtx.plan.OutputNames()[planCtx.plan.Schema().Len()-1]
}
return v, true
}
// handleOtherComparableSubq handles quantified comparisons other than = / !=, such as < any, > all, etc.
// For example, the query t.id < any (select s.id from s) is rewritten as t.id < (select max(s.id) from s).
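// A sketch of which aggregate is chosen, mirroring the useMin condition computed in handleCompareSubquery:
//
//	t.id <  ALL (...)  /  t.id <= ALL (...)  ->  compare against MIN
//	t.id >  ANY (...)  /  t.id >= ANY (...)  ->  compare against MIN
//	t.id >  ALL (...)  /  t.id >= ALL (...)  ->  compare against MAX
//	t.id <  ANY (...)  /  t.id <= ANY (...)  ->  compare against MAX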
func (er *expressionRewriter) handleOtherComparableSubq(planCtx *exprRewriterPlanCtx, lexpr, rexpr expression.Expression, np base.LogicalPlan, useMin bool, cmpFunc string, all, markNoDecorrelate bool) {
intest.AssertNotNil(planCtx)
plan4Agg := logicalop.LogicalAggregation{}.Init(planCtx.builder.ctx, planCtx.builder.getSelectOffset())
if hintinfo := planCtx.builder.TableHints(); hintinfo != nil {
plan4Agg.PreferAggType = hintinfo.PreferAggType
plan4Agg.PreferAggToCop = hintinfo.PreferAggToCop
}
plan4Agg.SetChildren(np)
// Create a "max" or "min" aggregation.
funcName := ast.AggFuncMax
if useMin {
funcName = ast.AggFuncMin
}
funcMaxOrMin, err := aggregation.NewAggFuncDesc(planCtx.builder.ctx.GetExprCtx(), funcName, []expression.Expression{rexpr}, false)
if err != nil {
er.err = err
return
}
// Create a column and append it to the schema of that aggregation.
colMaxOrMin := &expression.Column{
UniqueID: planCtx.builder.ctx.GetSessionVars().AllocPlanColumnID(),
RetType: funcMaxOrMin.RetTp,
}
colMaxOrMin.SetCoercibility(rexpr.Coercibility())
schema := expression.NewSchema(colMaxOrMin)
plan4Agg.SetOutputNames(append(plan4Agg.OutputNames(), types.EmptyName))
plan4Agg.SetSchema(schema)
plan4Agg.AggFuncs = []*aggregation.AggFuncDesc{funcMaxOrMin}
cond := expression.NewFunctionInternal(er.sctx, cmpFunc, types.NewFieldType(mysql.TypeTiny), lexpr, colMaxOrMin)
er.buildQuantifierPlan(planCtx, plan4Agg, cond, lexpr, rexpr, all, markNoDecorrelate)
}
// buildQuantifierPlan adds the extra conditions for an any / all subquery.
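// A sketch of the resulting condition shape, with innerHasNull standing for
// sum(rexpr IS NULL) != 0 and cnt for the count(1) aggregate built below:
//
//	ALL:  (cond AND if(innerHasNull, NULL, 1)) OR cnt = 0 OR if(lexpr IS NULL, NULL, 0)
//	ANY:  (cond OR  if(innerHasNull, NULL, 0)) AND cnt != 0 AND if(lexpr IS NULL, NULL, 1)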
func (er *expressionRewriter) buildQuantifierPlan(planCtx *exprRewriterPlanCtx, plan4Agg *logicalop.LogicalAggregation, cond, lexpr, rexpr expression.Expression, all, markNoDecorrelate bool) {
intest.AssertNotNil(planCtx)
innerIsNull := expression.NewFunctionInternal(er.sctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), rexpr)
outerIsNull := expression.NewFunctionInternal(er.sctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), lexpr)
exprCtx := planCtx.builder.ctx.GetExprCtx()
funcSum, err := aggregation.NewAggFuncDesc(exprCtx, ast.AggFuncSum, []expression.Expression{innerIsNull}, false)
if err != nil {
er.err = err
return
}
sessVars := planCtx.builder.ctx.GetSessionVars()
colSum := &expression.Column{
UniqueID: sessVars.AllocPlanColumnID(),
RetType: funcSum.RetTp,
}
plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcSum)
plan4Agg.Schema().Append(colSum)
innerHasNull := expression.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum, expression.NewZero())
// Build `count(1)` aggregation to check if subquery is empty.
funcCount, err := aggregation.NewAggFuncDesc(exprCtx, ast.AggFuncCount, []expression.Expression{expression.NewOne()}, false)
if err != nil {
er.err = err
return
}
colCount := &expression.Column{
UniqueID: sessVars.AllocPlanColumnID(),
RetType: funcCount.RetTp,
}
plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcCount)
plan4Agg.Schema().Append(colCount)
if all {
// The inner record set should not contain any null value. So for t.id < all(select s.id from s), it
// should be rewritten as t.id < min(s.id) and if(sum(s.id is null) != 0, null, true).
innerNullChecker := expression.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.NewNull(), expression.NewOne())
cond = expression.ComposeCNFCondition(er.sctx, cond, innerNullChecker)
// If the subquery is empty, it should always return true.
emptyChecker := expression.NewFunctionInternal(er.sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colCount, expression.NewZero())
// If outer key is null, and subquery is not empty, it should always return null, even when it is `null = all (1, 2)`.
outerNullChecker := expression.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.NewNull(), expression.NewZero())
cond = expression.ComposeDNFCondition(er.sctx, cond, emptyChecker, outerNullChecker)
} else {
// For "any" expression, if the subquery has null and the cond returns false, the result should be NULL.
// Specifically, `t.id < any (select s.id from s)` would be rewrote to `t.id < max(s.id) or if(sum(s.id is null) != 0, null, false)`
innerNullChecker := expression.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.NewNull(), expression.NewZero())
cond = expression.ComposeDNFCondition(er.sctx, cond, innerNullChecker)
// If the subquery is empty, it should always return false.
emptyChecker := expression.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colCount, expression.NewZero())
// If outer key is null, and subquery is not empty, it should return null.
outerNullChecker := expression.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.NewNull(), expression.NewOne())
cond = expression.ComposeCNFCondition(er.sctx, cond, emptyChecker, outerNullChecker)
}
// TODO: Add a Projection if any argument of aggregate funcs or group by items are scalar functions.
// plan4Agg.buildProjectionIfNecessary()
if !er.asScalar {
// For a semi LogicalApply without an aux column, it does not matter whether the result is false or null, so we can add the condition to the join predicate.
planCtx.plan, er.err = planCtx.builder.buildSemiApply(planCtx.plan, plan4Agg, []expression.Expression{cond}, false, false, false, markNoDecorrelate)
return
}
// If we treat the result as a scalar value, we add a projection with an extra column to output true, false or null.
outerSchemaLen := planCtx.plan.Schema().Len()
planCtx.plan = planCtx.builder.buildApplyWithJoinType(planCtx.plan, plan4Agg, base.InnerJoin, markNoDecorrelate)
joinSchema := planCtx.plan.Schema()
proj := logicalop.LogicalProjection{
Exprs: expression.Column2Exprs(joinSchema.Clone().Columns[:outerSchemaLen]),
}.Init(planCtx.builder.ctx, planCtx.builder.getSelectOffset())
proj.SetOutputNames(make([]*types.FieldName, outerSchemaLen, outerSchemaLen+1))
names := proj.OutputNames()
copy(names, planCtx.plan.OutputNames())
proj.SetOutputNames(names)
proj.SetSchema(expression.NewSchema(joinSchema.Clone().Columns[:outerSchemaLen]...))
proj.Exprs = append(proj.Exprs, cond)
proj.Schema().Append(&expression.Column{
UniqueID: sessVars.AllocPlanColumnID(),
RetType: cond.GetType(er.sctx.GetEvalCtx()),
})
proj.SetOutputNames(append(proj.OutputNames(), types.EmptyName))
proj.SetChildren(planCtx.plan)
planCtx.plan = proj
}
// handleNEAny handles the case of != any. For example, the query t.id != any (select s.id from s) is rewritten as
// t.id != s.id or count(distinct s.id) > 1 or [any checker]. If s.id contains two different values,
// at least one of them must not equal t.id.
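// Worked sketch: if s.id contains {1, 1, 2}, then count(distinct s.id) = 2 > 1 and every
// non-null t.id matches; if s.id contains only {7}, the condition degenerates to t.id != 7.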
func (er *expressionRewriter) handleNEAny(planCtx *exprRewriterPlanCtx, lexpr, rexpr expression.Expression, np base.LogicalPlan, markNoDecorrelate bool) {
intest.AssertNotNil(planCtx)
sctx := planCtx.builder.ctx
exprCtx := sctx.GetExprCtx()
// If the s.id column contains NULLs, the value used in the condition t.id != s.id should be a non-null one.
// So use the max function to filter out NULL.
maxFunc, err := aggregation.NewAggFuncDesc(exprCtx, ast.AggFuncMax, []expression.Expression{rexpr}, false)
if err != nil {
er.err = err
return
}
countFunc, err := aggregation.NewAggFuncDesc(exprCtx, ast.AggFuncCount, []expression.Expression{rexpr}, true)
if err != nil {
er.err = err
return
}
plan4Agg := logicalop.LogicalAggregation{
AggFuncs: []*aggregation.AggFuncDesc{maxFunc, countFunc},
}.Init(sctx, planCtx.builder.getSelectOffset())
if hintinfo := planCtx.builder.TableHints(); hintinfo != nil {
plan4Agg.PreferAggType = hintinfo.PreferAggType
plan4Agg.PreferAggToCop = hintinfo.PreferAggToCop
}
plan4Agg.SetChildren(np)
maxResultCol := &expression.Column{
UniqueID: sctx.GetSessionVars().AllocPlanColumnID(),
RetType: maxFunc.RetTp,
}
maxResultCol.SetCoercibility(rexpr.Coercibility())
count := &expression.Column{
UniqueID: sctx.GetSessionVars().AllocPlanColumnID(),
RetType: countFunc.RetTp,
}
plan4Agg.SetOutputNames(append(plan4Agg.OutputNames(), types.EmptyName, types.EmptyName))
plan4Agg.SetSchema(expression.NewSchema(maxResultCol, count))
gtFunc := expression.NewFunctionInternal(er.sctx, ast.GT, types.NewFieldType(mysql.TypeTiny), count, expression.NewOne())
neCond := expression.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(mysql.TypeTiny), lexpr, maxResultCol)
cond := expression.ComposeDNFCondition(er.sctx, gtFunc, neCond)
er.buildQuantifierPlan(planCtx, plan4Agg, cond, lexpr, rexpr, false, markNoDecorrelate)
}
// handleEQAll handles the case of = all. For example, the query t.id = all (select s.id from s) is rewritten as
// t.id = (select s.id from s having count(distinct s.id) <= 1 and [all checker]).
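// Worked sketch: if s.id contains only {3, 3}, count(distinct s.id) = 1 and the condition
// degenerates to t.id = 3; if s.id contains {3, 4}, the count check fails and no t.id can
// satisfy `= ALL` (modulo the NULL handling added in buildQuantifierPlan).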
func (er *expressionRewriter) handleEQAll(planCtx *exprRewriterPlanCtx, lexpr, rexpr expression.Expression, np base.LogicalPlan, markNoDecorrelate bool) {
intest.AssertNotNil(planCtx)
sctx := planCtx.builder.ctx
exprCtx := sctx.GetExprCtx()
// If the s.id column contains NULLs, the value used in the condition t.id = s.id should be a non-null one.
// So use the max function to filter out NULL.
maxFunc, err := aggregation.NewAggFuncDesc(exprCtx, ast.AggFuncMax, []expression.Expression{rexpr}, false)
if err != nil {
er.err = err
return
}
countFunc, err := aggregation.NewAggFuncDesc(exprCtx, ast.AggFuncCount, []expression.Expression{rexpr}, true)
if err != nil {
er.err = err
return
}
plan4Agg := logicalop.LogicalAggregation{
AggFuncs: []*aggregation.AggFuncDesc{maxFunc, countFunc},
}.Init(sctx, planCtx.builder.getSelectOffset())
if hintinfo := planCtx.builder.TableHints(); hintinfo != nil {
plan4Agg.PreferAggType = hintinfo.PreferAggType
plan4Agg.PreferAggToCop = hintinfo.PreferAggToCop
}
plan4Agg.SetChildren(np)
plan4Agg.SetOutputNames(append(plan4Agg.OutputNames(), types.EmptyName))
maxResultCol := &expression.Column{
UniqueID: sctx.GetSessionVars().AllocPlanColumnID(),
RetType: maxFunc.RetTp,
}
maxResultCol.SetCoercibility(rexpr.Coercibility())
plan4Agg.SetOutputNames(append(plan4Agg.OutputNames(), types.EmptyName))
count := &expression.Column{
UniqueID: sctx.GetSessionVars().AllocPlanColumnID(),
RetType: countFunc.RetTp,
}
plan4Agg.SetSchema(expression.NewSchema(maxResultCol, count))
leFunc := expression.NewFunctionInternal(er.sctx, ast.LE, types.NewFieldType(mysql.TypeTiny), count, expression.NewOne())
eqCond := expression.NewFunctionInternal(er.sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lexpr, maxResultCol)
cond := expression.ComposeCNFCondition(er.sctx, leFunc, eqCond)
er.buildQuantifierPlan(planCtx, plan4Agg, cond, lexpr, rexpr, true, markNoDecorrelate)
}
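// handleExistSubquery rewrites [NOT] EXISTS (subquery). A summary sketch of the paths
// below: correlated subqueries (and those containing CTE consumers, or when subquery
// preprocessing is disabled) become a semi apply; otherwise the subquery is optimized
// and evaluated eagerly, folding the EXISTS predicate into constant 1 or 0 depending on
// whether a first row exists (negated for NOT EXISTS), except under plain EXPLAIN with
// ExplainNonEvaledSubQuery enabled, where it is wrapped as a ScalarSubQueryExpr instead.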
func (er *expressionRewriter) handleExistSubquery(ctx context.Context, planCtx *exprRewriterPlanCtx, v *ast.ExistsSubqueryExpr) (ast.Node, bool) {
intest.AssertNotNil(planCtx)
b := planCtx.builder
ci := b.prepareCTECheckForSubQuery()
defer resetCTECheckForSubQuery(ci)
subq, ok := v.Sel.(*ast.SubqueryExpr)
if !ok {
er.err = errors.Errorf("Unknown exists type %T", v.Sel)
return v, true
}
np, hintFlags, err := er.buildSubquery(ctx, planCtx, subq, handlingExistsSubquery)
if err != nil {
er.err = err
return v, true
}
// Add LIMIT 1 when noDecorrelate is true for EXISTS subqueries to enable early exit
corCols := coreusage.ExtractCorColumnsBySchema4LogicalPlan(np, planCtx.plan.Schema())
noDecorrelate := isNoDecorrelate(planCtx, corCols, hintFlags, handlingExistsSubquery)
if noDecorrelate {
// Only add LIMIT 1 if the query doesn't already contain a LIMIT clause
if !hasLimit(np) {
limitClause := &ast.Limit{
Count: ast.NewValueExpr(1, "", ""),
}
var err error
np, err = planCtx.builder.buildLimit(np, limitClause)
if err != nil {
er.err = err
return v, true
}
}
}
np = er.popExistsSubPlan(planCtx, np)
semiJoinRewrite := hintFlags&hint.HintFlagSemiJoinRewrite > 0
if semiJoinRewrite && noDecorrelate {
b.ctx.GetSessionVars().StmtCtx.SetHintWarning(
"NO_DECORRELATE() and SEMI_JOIN_REWRITE() are in conflict. Both will be ineffective.")
noDecorrelate = false
semiJoinRewrite = false
}
if b.disableSubQueryPreprocessing || len(coreusage.ExtractCorrelatedCols4LogicalPlan(np)) > 0 || hasCTEConsumerInSubPlan(np) {
planCtx.plan, er.err = b.buildSemiApply(planCtx.plan, np, nil, er.asScalar, v.Not, semiJoinRewrite, noDecorrelate)
if er.err != nil || !er.asScalar {
return v, true
}
er.ctxStackAppend(planCtx.plan.Schema().Columns[planCtx.plan.Schema().Len()-1], planCtx.plan.OutputNames()[planCtx.plan.Schema().Len()-1])
} else {
// We don't want nth_plan hint to affect separately executed subqueries here, so disable nth_plan temporarily.
nthPlanBackup := b.ctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan
b.ctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan = -1
physicalPlan, _, err := DoOptimize(ctx, planCtx.builder.ctx, b.optFlag, np)
b.ctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan = nthPlanBackup
if err != nil {
er.err = err
return v, true
}
if b.ctx.GetSessionVars().StmtCtx.InExplainStmt && !b.ctx.GetSessionVars().StmtCtx.InExplainAnalyzeStmt && b.ctx.GetSessionVars().ExplainNonEvaledSubQuery {
newColID := b.ctx.GetSessionVars().AllocPlanColumnID()
subqueryCtx := ScalarSubqueryEvalCtx{
scalarSubQuery: physicalPlan,
ctx: ctx,
is: b.is,
outputColIDs: []int64{newColID},
}.Init(b.ctx, np.QueryBlockOffset())
scalarSubQ := &ScalarSubQueryExpr{
scalarSubqueryColID: newColID,
evalCtx: subqueryCtx,
}
scalarSubQ.RetType = np.Schema().Columns[0].GetType(er.sctx.GetEvalCtx())
scalarSubQ.SetCoercibility(np.Schema().Columns[0].Coercibility())
b.ctx.GetSessionVars().RegisterScalarSubQ(subqueryCtx)
if v.Not {
notWrapped, err := expression.NewFunction(b.ctx.GetExprCtx(), ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), scalarSubQ)
if err != nil {
er.err = err
return v, true
}
er.ctxStackAppend(notWrapped, types.EmptyName)
return v, true
}
er.ctxStackAppend(scalarSubQ, types.EmptyName)
return v, true
}
// register the subquery plan but continue with normal execution
subqueryCtx := ScalarSubqueryEvalCtx{
scalarSubQuery: physicalPlan,
ctx: ctx,
is: b.is,
}.Init(planCtx.builder.ctx, np.QueryBlockOffset())
newColIDs := make([]int64, 0, np.Schema().Len())
for range np.Schema().Columns {
newColID := planCtx.builder.ctx.GetSessionVars().AllocPlanColumnID()
newColIDs = append(newColIDs, newColID)
}
subqueryCtx.outputColIDs = newColIDs
planCtx.builder.ctx.GetSessionVars().RegisterScalarSubQ(subqueryCtx)
row, err := EvalSubqueryFirstRow(ctx, physicalPlan, b.is, b.ctx)
if err != nil {
er.err = err
return v, true
}
if (row != nil && !v.Not) || (row == nil && v.Not) {
er.ctxStackAppend(expression.NewSignedOne(), types.EmptyName)
} else {
er.ctxStackAppend(expression.NewSignedZero(), types.EmptyName)
}
}
return v, true
}
// popExistsSubPlan removes useless operators from the child plan of an EXISTS expression.
// See the comments inside the method for more details.
func (*expressionRewriter) popExistsSubPlan(planCtx *exprRewriterPlanCtx, p base.LogicalPlan) base.LogicalPlan {
intest.AssertNotNil(planCtx)
out:
for {
switch plan := p.(type) {
// These operators can be removed inside an EXISTS clause,
// e.g. exists(select count(*) from t order by a) is equivalent to exists(select count(*) from t).
case *logicalop.LogicalProjection, *logicalop.LogicalSort:
p = p.Children()[0]
case *logicalop.LogicalAggregation:
if len(plan.GroupByItems) == 0 {
p = logicalop.LogicalTableDual{RowCount: 1}.Init(planCtx.builder.ctx, planCtx.builder.getSelectOffset())
break out
}
p = p.Children()[0]
default:
break out
}
}
return p
}
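// handleInSubquery rewrites `expr [NOT] IN (subquery)`. A sketch of the two paths taken below:
//
//	-- when GetAllowInSubqToJoinAndAgg() is enabled, there is no NOT, no correlated column
//	-- from the current level, collations are compatible, and no scalar value is needed,
//	-- the pattern becomes an inner join over a de-duplicated subquery:
//	SELECT * FROM t WHERE t.a IN (SELECT s.a FROM s)
//	  -- roughly: t JOIN (SELECT DISTINCT s.a FROM s) ON t.a = s.a
//
//	-- otherwise it becomes an (anti / left-outer) semi apply with the equality as the join condition.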
func (er *expressionRewriter) handleInSubquery(ctx context.Context, planCtx *exprRewriterPlanCtx, v *ast.PatternInExpr) (ast.Node, bool) {
intest.AssertNotNil(planCtx)
ci := planCtx.builder.prepareCTECheckForSubQuery()
defer resetCTECheckForSubQuery(ci)
asScalar := er.asScalar
er.asScalar = true
v.Expr.Accept(er)
if er.err != nil {
return v, true
}
lexpr := er.ctxStack[len(er.ctxStack)-1]
subq, ok := v.Sel.(*ast.SubqueryExpr)
if !ok {
er.err = errors.Errorf("Unknown compare type %T", v.Sel)
return v, true
}
np, hintFlags, err := er.buildSubquery(ctx, planCtx, subq, handlingInSubquery)
if err != nil {
er.err = err
return v, true
}
lLen := expression.GetRowLen(lexpr)
if lLen != np.Schema().Len() {
er.err = expression.ErrOperandColumns.GenWithStackByArgs(lLen)
return v, true
}
var rexpr expression.Expression
if np.Schema().Len() == 1 {
rexpr = np.Schema().Columns[0]
rCol := rexpr.(*expression.Column)
// For AntiSemiJoin/LeftOuterSemiJoin/AntiLeftOuterSemiJoin, we cannot treat `in` expression as
// normal column equal condition, so we specially mark the inner operand here.
if v.Not || asScalar {
// If both input columns of the `in` expression are known to be non-null, we can treat the expression
// as a normal column-equal condition instead. Otherwise, mark the left and right sides as InOperand:
// e.g. during projection elimination, column substitution on the right side could otherwise turn
// <lcol EQ rcol(inOperand)> into <lcol EQ constant>, which is not a valid null-aware EQ
// (a NULL in lcol still needs to be handled in a null-aware way).
if !expression.ExprNotNull(er.sctx.GetEvalCtx(), lexpr) || !expression.ExprNotNull(er.sctx.GetEvalCtx(), rCol) {
rColCopy := *rCol
rColCopy.InOperand = true
rexpr = &rColCopy
lexpr = expression.SetExprColumnInOperand(lexpr)
}
}
} else {
args := make([]expression.Expression, 0, np.Schema().Len())
for i, col := range np.Schema().Columns {
if v.Not || asScalar {
larg := expression.GetFuncArg(lexpr, i)
// If both input columns of the `in` expression are known to be non-null, we can treat the expression
// as a normal column-equal condition instead. Otherwise, mark the left and right sides as InOperand.
if !expression.ExprNotNull(er.sctx.GetEvalCtx(), larg) || !expression.ExprNotNull(er.sctx.GetEvalCtx(), col) {
rarg := *col
rarg.InOperand = true
col = &rarg
if larg != nil {
lexpr.(*expression.ScalarFunction).GetArgs()[i] = expression.SetExprColumnInOperand(larg)
}
}
}
args = append(args, col)
}
rexpr, er.err = er.newFunction(ast.RowFunc, args[0].GetType(er.sctx.GetEvalCtx()), args...)
if er.err != nil {
return v, true
}
}
checkCondition, err := er.constructBinaryOpFunction(lexpr, rexpr, ast.EQ)
if err != nil {
er.err = err
return v, true
}
// If the leftKey and the rightKey have different collations, don't convert the sub-query to an inner join,
// since the conversion adds a distinct-agg on the right child and this distinct-agg doesn't have the right collation.
// To keep it simple, we forbid the conversion when the collations differ.
// Tested by TestCollateSubQuery.
lt, rt := lexpr.GetType(er.sctx.GetEvalCtx()), rexpr.GetType(er.sctx.GetEvalCtx())
collFlag := collate.CompatibleCollate(lt.GetCollate(), rt.GetCollate())
corCols := coreusage.ExtractCorColumnsBySchema4LogicalPlan(np, planCtx.plan.Schema())
noDecorrelate := isNoDecorrelate(planCtx, corCols, hintFlags, handlingInSubquery)
// If it's not of the form `not in (SUBQUERY)`,
// has no correlated column from the current-level plan (a correlated column from an upper level
// can be treated as a constant, because the upper LogicalApply cannot be eliminated since the current node is a join node),
// and no scalar value needs to be appended, we can rewrite it to an inner join.
if planCtx.builder.ctx.GetSessionVars().GetAllowInSubqToJoinAndAgg() && !v.Not && !asScalar && len(corCols) == 0 && collFlag {
// We need to try to eliminate the agg and the projection produced by this operation.
planCtx.builder.optFlag |= rule.FlagEliminateAgg
planCtx.builder.optFlag |= rule.FlagEliminateProjection
planCtx.builder.optFlag |= rule.FlagJoinReOrder
planCtx.builder.optFlag |= rule.FlagEmptySelectionEliminator
// Build distinct for the inner query.
agg, err := planCtx.builder.buildDistinct(np, np.Schema().Len())
if err != nil {
er.err = err
return v, true
}
// Build inner join above the aggregation.
join := logicalop.LogicalJoin{JoinType: base.InnerJoin}.Init(planCtx.builder.ctx, planCtx.builder.getSelectOffset())
join.SetChildren(planCtx.plan, agg)
join.SetSchema(expression.MergeSchema(planCtx.plan.Schema(), agg.Schema()))
join.SetOutputNames(make([]*types.FieldName, planCtx.plan.Schema().Len()+agg.Schema().Len()))
copy(join.OutputNames(), planCtx.plan.OutputNames())
copy(join.OutputNames()[planCtx.plan.Schema().Len():], agg.OutputNames())
join.AttachOnConds(expression.SplitCNFItems(checkCondition))
// set FullSchema and FullNames for this join
if left, ok := planCtx.plan.(*logicalop.LogicalJoin); ok && left.FullSchema != nil {
join.FullSchema = left.FullSchema
join.FullNames = left.FullNames
}
// Set join hint for this join.
if planCtx.builder.TableHints() != nil {
join.SetPreferredJoinTypeAndOrder(planCtx.builder.TableHints())
}
planCtx.plan = join
} else {
semiRewrite := hintFlags&hint.HintFlagSemiJoinRewrite > 0
planCtx.plan, er.err = planCtx.builder.buildSemiApply(planCtx.plan, np, expression.SplitCNFItems(checkCondition), asScalar, v.Not, semiRewrite, noDecorrelate)
if er.err != nil {
return v, true
}
}
er.ctxStackPop(1)
if asScalar {
col := planCtx.plan.Schema().Columns[planCtx.plan.Schema().Len()-1]
er.ctxStackAppend(col, planCtx.plan.OutputNames()[planCtx.plan.Schema().Len()-1])
}
return v, true
}
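// isNoDecorrelate decides whether the subquery should keep its correlation (stay as an
// Apply) rather than be decorrelated. A sketch of the rules implemented below:
//   - NO_DECORRELATE() with no correlated columns: warn and ignore the hint.
//   - SEMI_JOIN_REWRITE() on an EXISTS subquery: the session variable cannot force
//     no-decorrelation here, since the two hints would conflict and both be disabled later.
//   - a scalar or EXISTS subquery in the select list: EnableNoDecorrelateInSelect
//     (vardef.TiDBOptEnableNoDecorrelateInSelect) can enable it even without the hint.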
func isNoDecorrelate(planCtx *exprRewriterPlanCtx, corCols []*expression.CorrelatedColumn, hintFlags uint64, sCtx subQueryCtx) bool {
noDecorrelate := hintFlags&hint.HintFlagNoDecorrelate > 0
if len(corCols) == 0 {
if noDecorrelate {
planCtx.builder.ctx.GetSessionVars().StmtCtx.SetHintWarning(
"NO_DECORRELATE() is inapplicable because there are no correlated columns.")
noDecorrelate = false
}
} else {
semiJoinRewrite := hintFlags&hint.HintFlagSemiJoinRewrite > 0
// We can't override noDecorrelate via the variable for EXISTS subqueries with semi join rewrite,
// as this would cause a conflict that results in both being disabled in later code.
// SemiJoinRewrite does not check the variable TiDBOptEnableSemiJoinRewrite.
// If that variable is enabled, we can still choose NOT to decorrelate here.
if !(semiJoinRewrite && sCtx == handlingExistsSubquery) {
// Only support scalar and exists subqueries
validSubqType := sCtx == handlingScalarSubquery || sCtx == handlingExistsSubquery
if validSubqType && planCtx.curClause == fieldList { // subquery is in the select list
planCtx.builder.ctx.GetSessionVars().RecordRelevantOptVar(vardef.TiDBOptEnableNoDecorrelateInSelect)
// If it isn't already enabled via hint, and variable is set, then enable it
if !noDecorrelate && planCtx.builder.ctx.GetSessionVars().EnableNoDecorrelateInSelect {
noDecorrelate = true
}
}
}
}
return noDecorrelate
}
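// handleScalarSubquery rewrites a scalar subquery expression. A summary sketch of the
// paths below: correlated subqueries (and those containing CTE consumers, or when subquery
// preprocessing is disabled) become a LeftOuterJoin apply whose new column(s) replace the
// subquery; otherwise the subquery is optimized, its first row is evaluated eagerly and
// folded into Constant expressions, or it is wrapped as ScalarSubQueryExpr when
// ExplainNonEvaledSubQuery applies during EXPLAIN.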
func (er *expressionRewriter) handleScalarSubquery(ctx context.Context, planCtx *exprRewriterPlanCtx, v *ast.SubqueryExpr) (ast.Node, bool) {
intest.AssertNotNil(planCtx)
ci := planCtx.builder.prepareCTECheckForSubQuery()
defer resetCTECheckForSubQuery(ci)
np, hintFlags, err := er.buildSubquery(ctx, planCtx, v, handlingScalarSubquery)
if err != nil {
er.err = err
return v, true
}
np = planCtx.builder.buildMaxOneRow(np)
correlatedColumn := coreusage.ExtractCorColumnsBySchema4LogicalPlan(np, planCtx.plan.Schema())
noDecorrelate := isNoDecorrelate(planCtx, correlatedColumn, hintFlags, handlingScalarSubquery)
if planCtx.builder.disableSubQueryPreprocessing || len(coreusage.ExtractCorrelatedCols4LogicalPlan(np)) > 0 || hasCTEConsumerInSubPlan(np) {
planCtx.plan = planCtx.builder.buildApplyWithJoinType(planCtx.plan, np, base.LeftOuterJoin, noDecorrelate)
if np.Schema().Len() > 1 {
newCols := make([]expression.Expression, 0, np.Schema().Len())
for _, col := range np.Schema().Columns {
newCols = append(newCols, col)
}
expr, err1 := er.newFunction(ast.RowFunc, newCols[0].GetType(er.sctx.GetEvalCtx()), newCols...)
if err1 != nil {
er.err = err1
return v, true
}
er.ctxStackAppend(expr, types.EmptyName)
} else {
er.ctxStackAppend(planCtx.plan.Schema().Columns[planCtx.plan.Schema().Len()-1], planCtx.plan.OutputNames()[planCtx.plan.Schema().Len()-1])
}
return v, true
}
// We don't want nth_plan hint to affect separately executed subqueries here, so disable nth_plan temporarily.
nthPlanBackup := planCtx.builder.ctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan
planCtx.builder.ctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan = -1
physicalPlan, _, err := DoOptimize(ctx, planCtx.builder.ctx, planCtx.builder.optFlag, np)
planCtx.builder.ctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan = nthPlanBackup
if err != nil {
er.err = err
return v, true
}
if planCtx.builder.ctx.GetSessionVars().StmtCtx.InExplainStmt && !planCtx.builder.ctx.GetSessionVars().StmtCtx.InExplainAnalyzeStmt && planCtx.builder.ctx.GetSessionVars().ExplainNonEvaledSubQuery {
subqueryCtx := ScalarSubqueryEvalCtx{
scalarSubQuery: physicalPlan,
ctx: ctx,
is: planCtx.builder.is,
}.Init(planCtx.builder.ctx, np.QueryBlockOffset())
newColIDs := make([]int64, 0, np.Schema().Len())
newScalarSubQueryExprs := make([]expression.Expression, 0, np.Schema().Len())
for _, col := range np.Schema().Columns {
newColID := planCtx.builder.ctx.GetSessionVars().AllocPlanColumnID()
scalarSubQ := &ScalarSubQueryExpr{
scalarSubqueryColID: newColID,
evalCtx: subqueryCtx,
}
scalarSubQ.RetType = col.RetType
scalarSubQ.SetCoercibility(col.Coercibility())
newColIDs = append(newColIDs, newColID)
newScalarSubQueryExprs = append(newScalarSubQueryExprs, scalarSubQ)
}
subqueryCtx.outputColIDs = newColIDs
planCtx.builder.ctx.GetSessionVars().RegisterScalarSubQ(subqueryCtx)
if len(newScalarSubQueryExprs) == 1 {
er.ctxStackAppend(newScalarSubQueryExprs[0], types.EmptyName)
} else {
rowFunc, err := er.newFunction(ast.RowFunc, newScalarSubQueryExprs[0].GetType(er.sctx.GetEvalCtx()), newScalarSubQueryExprs...)
if err != nil {
er.err = err
return v, true
}
er.ctxStack = append(er.ctxStack, rowFunc)
}
return v, true
}
// register the subquery plan but continue with normal execution
subqueryCtx := ScalarSubqueryEvalCtx{
scalarSubQuery: physicalPlan,
ctx: ctx,
is: planCtx.builder.is,
}.Init(planCtx.builder.ctx, np.QueryBlockOffset())
newColIDs := make([]int64, 0, np.Schema().Len())
for range np.Schema().Columns {
newColID := planCtx.builder.ctx.GetSessionVars().AllocPlanColumnID()
newColIDs = append(newColIDs, newColID)
}
subqueryCtx.outputColIDs = newColIDs
planCtx.builder.ctx.GetSessionVars().RegisterScalarSubQ(subqueryCtx)
row, err := EvalSubqueryFirstRow(ctx, physicalPlan, planCtx.builder.is, planCtx.builder.ctx)
if err != nil {
er.err = err
return v, true
}
newCols := make([]expression.Expression, 0, np.Schema().Len())
for i, data := range row {
constant := &expression.Constant{
Value: data,
RetType: np.Schema().Columns[i].GetType(er.sctx.GetEvalCtx()),
SubqueryRefID: newColIDs[i],
}
constant.SetCoercibility(np.Schema().Columns[i].Coercibility())
newCols = append(newCols, constant)
}
if np.Schema().Len() > 1 {
expr, err := er.newFunction(ast.RowFunc, newCols[0].GetType(er.sctx.GetEvalCtx()), newCols...)
if err != nil {
er.err = err
return v, true
}
er.ctxStackAppend(expr, types.EmptyName)
} else {
er.ctxStackAppend(newCols[0], types.EmptyName)
}
return v, true
}
func hasCTEConsumerInSubPlan(p base.LogicalPlan) bool {
if _, ok := p.(*logicalop.LogicalCTE); ok {
return true
}
return slices.ContainsFunc(p.Children(), hasCTEConsumerInSubPlan)
}
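// initConstantRepertoire marks a string constant as ASCII unless any byte of its value is
// >= 0x80, in which case the repertoire is UNICODE (e.g. 'abc' stays ASCII, while a value
// containing 'é' becomes UNICODE).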
func initConstantRepertoire(ctx expression.EvalContext, c *expression.Constant) {
c.SetRepertoire(expression.ASCII)
if c.GetType(ctx).EvalType() == types.ETString {
for _, b := range c.Value.GetBytes() {
// If any byte in the constant is not ASCII, set the repertoire to UNICODE.
if b >= 0x80 {
c.SetRepertoire(expression.UNICODE)
break
}
}
}
}
func (er *expressionRewriter) adjustUTF8MB4Collation(tp *types.FieldType) {
if tp.GetFlag()&mysql.UnderScoreCharsetFlag > 0 && charset.CharsetUTF8MB4 == tp.GetCharset() {
tp.SetCollate(er.sctx.GetDefaultCollationForUTF8MB4())
}
}
// Leave implements Visitor interface.
func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok bool) {
if er.err != nil {
return retNode, false
}
var inNode = originInNode
if er.preprocess != nil {
inNode = er.preprocess(inNode)
}
withPlanCtx := func(fn func(*exprRewriterPlanCtx), detail string) {
planCtx, err := er.requirePlanCtx(inNode, detail)
if err != nil {
er.err = err
return
}
fn(planCtx)
}
switch v := inNode.(type) {
case *ast.AggregateFuncExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.WhenClause,
*ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr, *ast.ValuesExpr, *ast.WindowFuncExpr, *ast.TableNameExpr:
case *driver.ValueExpr:
// Set the NotNull flag appropriately for the constant value.
retType := v.Type.Clone()
switch v.Datum.Kind() {
case types.KindNull:
retType.DelFlag(mysql.NotNullFlag)
default:
retType.AddFlag(mysql.NotNullFlag)
}
v.Datum.SetValue(v.Datum.GetValue(), retType)
value := &expression.Constant{Value: v.Datum, RetType: retType}
initConstantRepertoire(er.sctx.GetEvalCtx(), value)
er.adjustUTF8MB4Collation(retType)
if er.err != nil {
return retNode, false
}
er.ctxStackAppend(value, types.EmptyName)
case *driver.ParamMarkerExpr:
er.toParamMarker(v)
case *ast.VariableExpr:
if v.IsSystem {
withPlanCtx(func(planCtx *exprRewriterPlanCtx) {
er.rewriteSystemVariable(planCtx, v)
}, "accessing system variable requires plan context")
} else {
er.rewriteUserVariable(v)
}
case *ast.FuncCallExpr:
switch v.FnName.L {
case ast.Grouping:
withPlanCtx(func(planCtx *exprRewriterPlanCtx) {
er.funcCallToExpressionWithPlanCtx(planCtx, v)
}, "grouping function requires plan context")
default:
if _, ok := expression.TryFoldFunctions[v.FnName.L]; ok {
er.tryFoldCounter--
}
er.funcCallToExpression(v)
if _, ok := expression.DisableFoldFunctions[v.FnName.L]; ok {
er.disableFoldCounter--
}
}
case *ast.TableName:
er.toTable(v)
case *ast.ColumnName:
er.toColumn(v)
case *ast.UnaryOperationExpr:
er.unaryOpToExpression(v)
case *ast.BinaryOperationExpr:
if v.Op == opcode.LogicAnd || v.Op == opcode.LogicOr {
er.tryFoldCounter--
}
er.binaryOpToExpression(v)
case *ast.BetweenExpr:
er.betweenToExpression(v)
case *ast.CaseExpr:
if _, ok := expression.TryFoldFunctions["case"]; ok {
er.tryFoldCounter--
}
er.caseToExpression(v)
if _, ok := expression.DisableFoldFunctions["case"]; ok {
er.disableFoldCounter--
}
case *ast.FuncCastExpr:
if v.Tp.IsArray() && !er.allowBuildCastArray {
er.err = expression.ErrNotSupportedYet.GenWithStackByArgs("Use of CAST( .. AS .. ARRAY) outside of functional index in CREATE(non-SELECT)/ALTER TABLE or in general expressions")
return retNode, false
}
arg := er.ctxStack[len(er.ctxStack)-1]
er.err = expression.CheckArgsNotMultiColumnRow(arg)
if er.err != nil {
return retNode, false
}
// check the decimal precision of "CAST(AS TIME)".
er.err = er.checkTimePrecision(v.Tp)
if er.err != nil {
return retNode, false
}
castFunction, err := expression.BuildCastFunctionWithCheck(er.sctx, arg, v.Tp, false, v.ExplicitCharSet)
if err != nil {
er.err = err
return retNode, false
}
if v.Tp.EvalType() == types.ETString {
castFunction.SetCoercibility(expression.CoercibilityImplicit)
if v.Tp.GetCharset() == charset.CharsetASCII {
castFunction.SetRepertoire(expression.ASCII)
} else {
castFunction.SetRepertoire(expression.UNICODE)
}
} else {
castFunction.SetCoercibility(expression.CoercibilityNumeric)
castFunction.SetRepertoire(expression.ASCII)
}
er.ctxStack[len(er.ctxStack)-1] = castFunction
er.ctxNameStk[len(er.ctxNameStk)-1] = types.EmptyName
case *ast.JSONSumCrc32Expr:
arg := er.ctxStack[len(er.ctxStack)-1]
jsonSumFunction, err := expression.BuildJSONSumCrc32FunctionWithCheck(er.sctx, arg, v.Tp)
if err != nil {
er.err = err
return retNode, false
}
jsonSumFunction.SetCoercibility(expression.CoercibilityNumeric)
jsonSumFunction.SetRepertoire(expression.ASCII)
er.ctxStack[len(er.ctxStack)-1] = jsonSumFunction
er.ctxNameStk[len(er.ctxNameStk)-1] = types.EmptyName
case *ast.PatternLikeOrIlikeExpr:
er.patternLikeOrIlikeToExpression(v)
case *ast.PatternRegexpExpr:
er.regexpToScalarFunc(v)
case *ast.RowExpr:
er.rowToScalarFunc(v)
case *ast.PatternInExpr:
if v.Sel == nil {
er.inToExpression(len(v.List), v.Not, &v.Type)
}
case *ast.PositionExpr:
withPlanCtx(func(planCtx *exprRewriterPlanCtx) {
er.positionToScalarFunc(planCtx, v)
}, "")
case *ast.IsNullExpr:
er.isNullToExpression(v)
case *ast.IsTruthExpr:
er.isTrueToScalarFunc(v)
case *ast.DefaultExpr:
if planCtx := er.planCtx; planCtx != nil {
er.evalDefaultExprWithPlanCtx(planCtx, v)
} else if er.sourceTable != nil {
er.evalDefaultExprForTable(v, er.sourceTable)
} else {
er.err = errors.Errorf("Unsupported expr %T when source table not provided", v)
}
// TODO: Perhaps we don't need to transcode these back to generic integers/strings
case *ast.TrimDirectionExpr:
er.ctxStackAppend(&expression.Constant{
Value: types.NewIntDatum(int64(v.Direction)),
RetType: types.NewFieldType(mysql.TypeTiny),
}, types.EmptyName)
case *ast.TimeUnitExpr:
er.ctxStackAppend(&expression.Constant{
Value: types.NewStringDatum(v.Unit.String()),
RetType: types.NewFieldType(mysql.TypeVarchar),
}, types.EmptyName)
case *ast.GetFormatSelectorExpr:
er.ctxStackAppend(&expression.Constant{
Value: types.NewStringDatum(v.Selector.String()),
RetType: types.NewFieldType(mysql.TypeVarchar),
}, types.EmptyName)
case *ast.SetCollationExpr:
arg := er.ctxStack[len(er.ctxStack)-1]
if collate.NewCollationEnabled() {
var collInfo *charset.Collation
// TODO(bb7133): use charset.ValidCharsetAndCollation when its bug is fixed.
if collInfo, er.err = collate.GetCollationByName(v.Collate); er.err != nil {
break
}
chs := arg.GetType(er.sctx.GetEvalCtx()).GetCharset()
// if the field is json, the charset is always utf8mb4.
if arg.GetType(er.sctx.GetEvalCtx()).GetType() == mysql.TypeJSON {
chs = mysql.UTF8MB4Charset
}
if chs != "" && collInfo.CharsetName != chs {
er.err = charset.ErrCollationCharsetMismatch.GenWithStackByArgs(collInfo.Name, chs)
break
}
}
// SetCollationExpr sets the collation explicitly, even when the evaluation type of the expression is non-string.
if _, ok := arg.(*expression.Column); ok || arg.GetType(er.sctx.GetEvalCtx()).GetType() == mysql.TypeJSON {
if arg.GetType(er.sctx.GetEvalCtx()).GetType() == mysql.TypeEnum || arg.GetType(er.sctx.GetEvalCtx()).GetType() == mysql.TypeSet {
er.err = plannererrors.ErrNotSupportedYet.GenWithStackByArgs("use collate clause for enum or set")
break
}
// Wrap a cast here to avoid changing the original FieldType of the column expression.
exprType := arg.GetType(er.sctx.GetEvalCtx()).Clone()
// If the arg type is JSON, cast it to LONGTEXT when there is a COLLATE clause.
if arg.GetType(er.sctx.GetEvalCtx()).GetType() == mysql.TypeJSON {
exprType = types.NewFieldType(mysql.TypeLongBlob)
exprType.SetCharset(mysql.UTF8MB4Charset)
}
exprType.SetCollate(v.Collate)
casted := expression.BuildCastFunction(er.sctx, arg, exprType)
arg = casted
er.ctxStackPop(1)
er.ctxStackAppend(casted, types.EmptyName)
} else {
// For constants and scalar functions, we can set the collation directly.
arg.GetType(er.sctx.GetEvalCtx()).SetCollate(v.Collate)
}
er.ctxStack[len(er.ctxStack)-1].SetCoercibility(expression.CoercibilityExplicit)
er.ctxStack[len(er.ctxStack)-1].SetCharsetAndCollation(arg.GetType(er.sctx.GetEvalCtx()).GetCharset(), arg.GetType(er.sctx.GetEvalCtx()).GetCollate())
default:
er.err = errors.Errorf("UnknownType: %T", v)
return retNode, false
}
if er.err != nil {
return retNode, false
}
return originInNode, true
}
// newFunctionWithInit chooses which expression.NewFunctionImpl() will be used.
func (er *expressionRewriter) newFunctionWithInit(funcName string, retType *types.FieldType, init expression.ScalarFunctionCallBack, args ...expression.Expression) (ret expression.Expression, err error) {
if init != nil {
ret, err = expression.NewFunctionWithInit(er.sctx, funcName, retType, init, args...)
} else if er.disableFoldCounter > 0 {
ret, err = expression.NewFunctionBase(er.sctx, funcName, retType, args...)
} else if er.tryFoldCounter > 0 {
ret, err = expression.NewFunctionTryFold(er.sctx, funcName, retType, args...)
} else {
ret, err = expression.NewFunction(er.sctx, funcName, retType, args...)
}
if err != nil {
return
}
if scalarFunc, ok := ret.(*expression.ScalarFunction); ok {
if er.planCtx != nil && er.planCtx.builder != nil {
er.planCtx.builder.ctx.BuiltinFunctionUsageInc(scalarFunc.Function.PbCode().String())
}
}
return
}
// newFunction delegates to newFunctionWithInit with a nil init callback.
func (er *expressionRewriter) newFunction(funcName string, retType *types.FieldType, args ...expression.Expression) (ret expression.Expression, err error) {
return er.newFunctionWithInit(funcName, retType, nil, args...)
}
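// checkTimePrecision returns an error when a CAST target of duration type declares a fractional
// seconds precision larger than types.MaxFsp.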
func (*expressionRewriter) checkTimePrecision(ft *types.FieldType) error {
if ft.EvalType() == types.ETDuration && ft.GetDecimal() > types.MaxFsp {
return plannererrors.ErrTooBigPrecision.GenWithStackByArgs(ft.GetDecimal(), "CAST", types.MaxFsp)
}
return nil
}
func (er *expressionRewriter) useCache() bool {
return er.sctx.IsUseCache()
}
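// rewriteUserVariable rewrites a user variable expression.
// An assignment such as `@x := expr` becomes a setvar function over the already-rewritten expr, and the
// variable's field type is recorded via SetUserVarType for later reads; a plain reference such as `@x`
// becomes a getvar function whose result type is the recorded type, or a VARCHAR type by default.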
func (er *expressionRewriter) rewriteUserVariable(v *ast.VariableExpr) {
stkLen := len(er.ctxStack)
name := strings.ToLower(v.Name)
evalCtx := er.sctx.GetEvalCtx()
if v.Value != nil {
if !evalCtx.GetOptionalPropSet().Contains(exprctx.OptPropSessionVars) {
er.err = errors.Errorf("rewriting user variable requires '%s' in evalCtx", exprctx.OptPropSessionVars.String())
return
}
sessionVars, err := expropt.SessionVarsPropReader{}.GetSessionVars(evalCtx)
if err != nil {
er.err = err
return
}
intest.Assert(er.planCtx == nil || sessionVars == er.planCtx.builder.ctx.GetSessionVars())
tp := er.ctxStack[stkLen-1].GetType(er.sctx.GetEvalCtx())
er.ctxStack[stkLen-1], er.err = er.newFunction(ast.SetVar, tp,
expression.DatumToConstant(types.NewDatum(name), mysql.TypeString, 0),
er.ctxStack[stkLen-1])
er.ctxNameStk[stkLen-1] = types.EmptyName
// Store the field type of the variable into SessionVars.UserVarTypes.
// Normally we can infer the type from SessionVars.Users, but we need SessionVars.UserVarTypes when
// GetVar has not been executed to fill the SessionVars.Users.
sessionVars.SetUserVarType(name, tp)
return
}
tp, ok := evalCtx.GetUserVarsReader().GetUserVarType(name)
if !ok {
tp = types.NewFieldType(mysql.TypeVarString)
tp.SetFlen(mysql.MaxFieldVarCharLength)
}
f, err := er.newFunction(ast.GetVar, tp, expression.DatumToConstant(types.NewStringDatum(name), mysql.TypeString, 0))
if err != nil {
er.err = err
return
}
f.SetCoercibility(expression.CoercibilityImplicit)
er.ctxStackAppend(f, types.EmptyName)
}
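// rewriteSystemVariable rewrites a system variable expression such as @@autocommit, @@global.xxx or
// @@session.xxx: it validates the variable and the requested scope, reads the current value, and pushes
// it onto the stack as a typed constant.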
func (er *expressionRewriter) rewriteSystemVariable(planCtx *exprRewriterPlanCtx, v *ast.VariableExpr) {
name := strings.ToLower(v.Name)
sessionVars := planCtx.builder.ctx.GetSessionVars()
sysVar := variable.GetSysVar(name)
if sysVar == nil {
er.err = variable.ErrUnknownSystemVar.FastGenByArgs(name)
if err := variable.CheckSysVarIsRemoved(name); err != nil {
// Removed vars still return an error, but we customize it from
// "unknown" to an explanation of why it is not supported.
// This is important so users at least know they had the name correct.
er.err = err
}
return
}
if sysVar.IsNoop && !vardef.EnableNoopVariables.Load() {
// The variable does nothing; append a warning to the statement output.
sessionVars.StmtCtx.AppendWarning(plannererrors.ErrGettingNoopVariable.FastGenByArgs(sysVar.Name))
}
if sem.IsEnabled() && sem.IsInvisibleSysVar(sysVar.Name) {
err := plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("RESTRICTED_VARIABLES_ADMIN")
planCtx.builder.visitInfo = appendDynamicVisitInfo(planCtx.builder.visitInfo, []string{"RESTRICTED_VARIABLES_ADMIN"}, false, err)
}
if v.ExplicitScope && !sysVar.HasNoneScope() {
if v.IsGlobal && !(sysVar.HasGlobalScope() || sysVar.HasInstanceScope()) {
er.err = variable.ErrIncorrectScope.GenWithStackByArgs(name, "SESSION")
return
}
if v.IsInstance && !sysVar.HasInstanceScope() {
er.err = variable.ErrIncorrectScope.GenWithStackByArgs(name, "SESSION or GLOBAL")
return
}
if !v.IsGlobal && !v.IsInstance {
if !sysVar.HasSessionScope() {
er.err = variable.ErrIncorrectScope.GenWithStackByArgs(name, "GLOBAL")
return
}
if sysVar.InternalSessionVariable {
er.err = variable.ErrUnknownSystemVar.GenWithStackByArgs(name)
return
}
}
}
var val string
var err error
if sysVar.HasNoneScope() {
val = sysVar.Value
} else if v.IsGlobal || v.IsInstance {
val, err = sessionVars.GetGlobalSystemVar(er.ctx, name)
} else {
val, err = sessionVars.GetSessionOrGlobalSystemVar(er.ctx, name)
}
if err != nil {
er.err = err
return
}
nativeVal, nativeType, nativeFlag := sysVar.GetNativeValType(val)
e := expression.DatumToConstant(nativeVal, nativeType, nativeFlag)
switch nativeType {
case mysql.TypeVarString:
charset, _ := sessionVars.GetSystemVar(vardef.CharacterSetConnection)
e.GetType(er.sctx.GetEvalCtx()).SetCharset(charset)
collate, _ := sessionVars.GetSystemVar(vardef.CollationConnection)
e.GetType(er.sctx.GetEvalCtx()).SetCollate(collate)
case mysql.TypeLong, mysql.TypeLonglong:
e.GetType(er.sctx.GetEvalCtx()).SetCharset(charset.CharsetBin)
e.GetType(er.sctx.GetEvalCtx()).SetCollate(charset.CollationBin)
default:
er.err = errors.Errorf("Not supported type(%x) in GetNativeValType() function", nativeType)
return
}
er.ctxStackAppend(e, types.EmptyName)
}
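// unaryOpToExpression rewrites a unary operator. Unary plus is a no-op; minus, bit negation and logical
// NOT are turned into the corresponding scalar functions applied to the stack top.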
func (er *expressionRewriter) unaryOpToExpression(v *ast.UnaryOperationExpr) {
stkLen := len(er.ctxStack)
var op string
switch v.Op {
case opcode.Plus:
// The expression (+a) is equivalent to a, so nothing needs to be rewritten.
return
case opcode.Minus:
op = ast.UnaryMinus
case opcode.BitNeg:
op = ast.BitNeg
case opcode.Not, opcode.Not2:
op = ast.UnaryNot
default:
er.err = errors.Errorf("Unknown Unary Op %T", v.Op)
return
}
if expression.GetRowLen(er.ctxStack[stkLen-1]) != 1 {
er.err = expression.ErrOperandColumns.GenWithStackByArgs(1)
return
}
er.ctxStack[stkLen-1], er.err = er.newFunction(op, &v.Type, er.ctxStack[stkLen-1])
er.ctxNameStk[stkLen-1] = types.EmptyName
}
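// binaryOpToExpression rewrites a binary operator: comparison operators go through
// constructBinaryOpFunction (which also handles row operands), while all other operators require scalar
// operands. The two operands are popped from the stack and replaced by the resulting function.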
func (er *expressionRewriter) binaryOpToExpression(v *ast.BinaryOperationExpr) {
stkLen := len(er.ctxStack)
var function expression.Expression
switch v.Op {
case opcode.EQ, opcode.NE, opcode.NullEQ, opcode.GT, opcode.GE, opcode.LT, opcode.LE:
function, er.err = er.constructBinaryOpFunction(er.ctxStack[stkLen-2], er.ctxStack[stkLen-1],
v.Op.String())
default:
lLen := expression.GetRowLen(er.ctxStack[stkLen-2])
rLen := expression.GetRowLen(er.ctxStack[stkLen-1])
if lLen != 1 || rLen != 1 {
er.err = expression.ErrOperandColumns.GenWithStackByArgs(1)
return
}
function, er.err = er.newFunction(v.Op.String(), types.NewFieldType(mysql.TypeUnspecified), er.ctxStack[stkLen-2:]...)
}
if er.err != nil {
return
}
er.ctxStackPop(2)
er.ctxStackAppend(function, types.EmptyName)
}
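// notToExpression builds op(args...) with return type tp and, when hasNot is set, wraps the result with
// a unary NOT.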
func (er *expressionRewriter) notToExpression(hasNot bool, op string, tp *types.FieldType,
args ...expression.Expression) expression.Expression {
opFunc, err := er.newFunction(op, tp, args...)
if err != nil {
er.err = err
return nil
}
if !hasNot {
return opFunc
}
opFunc, err = er.newFunction(ast.UnaryNot, tp, opFunc)
if err != nil {
er.err = err
return nil
}
return opFunc
}
func (er *expressionRewriter) isNullToExpression(v *ast.IsNullExpr) {
stkLen := len(er.ctxStack)
if expression.GetRowLen(er.ctxStack[stkLen-1]) != 1 {
er.err = expression.ErrOperandColumns.GenWithStackByArgs(1)
return
}
function := er.notToExpression(v.Not, ast.IsNull, &v.Type, er.ctxStack[stkLen-1])
er.ctxStackPop(1)
er.ctxStackAppend(function, types.EmptyName)
}
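// positionToScalarFunc resolves a positional reference (e.g. `ORDER BY 2`, or a position given by a
// constant expression) to the corresponding column of the current schema. Out-of-range positions and
// hidden columns produce an unknown-column error.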
func (er *expressionRewriter) positionToScalarFunc(planCtx *exprRewriterPlanCtx, v *ast.PositionExpr) {
intest.AssertNotNil(planCtx)
pos := v.N
str := strconv.Itoa(pos)
if v.P != nil {
stkLen := len(er.ctxStack)
val := er.ctxStack[stkLen-1]
intNum, isNull, err := expression.GetIntFromConstant(er.sctx.GetEvalCtx(), val)
str = "?"
if err == nil {
if isNull {
return
}
pos = intNum
er.ctxStackPop(1)
}
er.err = err
}
if er.err == nil && pos > 0 && pos <= er.schema.Len() && !er.schema.Columns[pos-1].IsHidden {
er.ctxStackAppend(er.schema.Columns[pos-1], er.names[pos-1])
} else {
er.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs(str, clauseMsg[planCtx.builder.curClause])
}
}
func (er *expressionRewriter) isTrueToScalarFunc(v *ast.IsTruthExpr) {
stkLen := len(er.ctxStack)
op := ast.IsTruthWithoutNull
if v.True == 0 {
op = ast.IsFalsity
}
if expression.GetRowLen(er.ctxStack[stkLen-1]) != 1 {
er.err = expression.ErrOperandColumns.GenWithStackByArgs(1)
return
}
function := er.notToExpression(v.Not, op, &v.Type, er.ctxStack[stkLen-1])
er.ctxStackPop(1)
er.ctxStackAppend(function, types.EmptyName)
}
// inToExpression converts an IN expression to a scalar function. The argument lLen is the length of the IN list.
// The argument not indicates whether the expression is NOT IN. tp stands for the expression type, which is always bool.
// When the arguments cannot be handled by a single `in` function (e.g. the list elements have mixed types),
// `a in (b, c, d)` is rewritten as `(a = b) or (a = c) or (a = d)`.
func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.FieldType) {
stkLen := len(er.ctxStack)
l := expression.GetRowLen(er.ctxStack[stkLen-lLen-1])
for i := range lLen {
if l != expression.GetRowLen(er.ctxStack[stkLen-lLen+i]) {
er.err = expression.ErrOperandColumns.GenWithStackByArgs(l)
return
}
}
args := er.ctxStack[stkLen-lLen-1:]
leftFt := args[0].GetType(er.sctx.GetEvalCtx())
leftEt, leftIsNull := leftFt.EvalType(), leftFt.GetType() == mysql.TypeNull
if leftIsNull {
er.ctxStackPop(lLen + 1)
er.ctxStackAppend(expression.NewNull(), types.EmptyName)
return
}
if leftEt == types.ETInt {
for i := 1; i < len(args); i++ {
if c, ok := args[i].(*expression.Constant); ok {
var isExceptional bool
if expression.MaybeOverOptimized4PlanCache(er.sctx, c) {
if c.GetType(er.sctx.GetEvalCtx()).EvalType() == types.ETInt {
continue // no need to refine it
}
er.sctx.SetSkipPlanCache(fmt.Sprintf("'%v' may be converted to INT", c.StringWithCtx(er.sctx.GetEvalCtx(), errors.RedactLogDisable)))
if err := expression.RemoveMutableConst(er.sctx, c); err != nil {
er.err = err
return
}
}
args[i], isExceptional = expression.RefineComparedConstant(er.sctx, *leftFt, c, opcode.EQ)
if isExceptional {
args[i] = c
}
}
}
}
allSameType := true
for _, arg := range args[1:] {
if arg.GetType(er.sctx.GetEvalCtx()).GetType() != mysql.TypeNull && expression.GetAccurateCmpType(er.sctx.GetEvalCtx(), args[0], arg) != leftEt {
allSameType = false
break
}
}
var function expression.Expression
if allSameType && l == 1 && lLen > 1 {
function = er.notToExpression(not, ast.In, tp, er.ctxStack[stkLen-lLen-1:]...)
} else {
// If we rewrite IN to EQ, we need to decide which collation the EQ comparisons use.
coll := er.deriveCollationForIn(l, lLen, args)
if er.err != nil {
return
}
er.castCollationForIn(l, lLen, stkLen, coll)
eqFunctions := make([]expression.Expression, 0, lLen)
for i := stkLen - lLen; i < stkLen; i++ {
expr, err := er.constructBinaryOpFunction(args[0], er.ctxStack[i], ast.EQ)
if err != nil {
er.err = err
return
}
eqFunctions = append(eqFunctions, expr)
}
function = expression.ComposeDNFCondition(er.sctx, eqFunctions...)
if not {
var err error
function, err = er.newFunction(ast.UnaryNot, tp, function)
if err != nil {
er.err = err
return
}
}
}
er.ctxStackPop(lLen + 1)
er.ctxStackAppend(function, types.EmptyName)
}
// deriveCollationForIn derives the collation for an IN expression.
// We don't handle the case where the element is a tuple, such as (a, b, c) in ((x1, y1, z1), (x2, y2, z2)).
func (er *expressionRewriter) deriveCollationForIn(colLen int, _ int, args []expression.Expression) *expression.ExprCollation {
if colLen == 1 {
// a in (x, y, z) => coll[0]
coll2, err := expression.CheckAndDeriveCollationFromExprs(er.sctx, "IN", types.ETInt, args...)
er.err = err
if er.err != nil {
return nil
}
return coll2
}
return nil
}
// castCollationForIn casts the collation of the arguments in the IN clause to make sure the collation used is still
// correct after we rewrite the IN to equality expressions.
func (er *expressionRewriter) castCollationForIn(colLen int, elemCnt int, stkLen int, coll *expression.ExprCollation) {
// We don't handle the case where the element is a tuple, such as (a, b, c) in ((x1, y1, z1), (x2, y2, z2)).
if colLen != 1 {
return
}
if !collate.NewCollationEnabled() {
// See https://github.com/pingcap/tidb/issues/52772
// This function applies CoercibilityExplicit to the casted expression, but some checks (during ColumnSubstituteImpl)
// are missed when the new collation framework is disabled, which leads to a panic.
// To work around this issue, we can skip this function; that is fine since the new collation is disabled anyway.
return
}
for i := stkLen - elemCnt; i < stkLen; i++ {
// todo: consider refining the code and reusing expression.BuildCollationFunction here
if er.ctxStack[i].GetType(er.sctx.GetEvalCtx()).EvalType() == types.ETString {
rowFunc, ok := er.ctxStack[i].(*expression.ScalarFunction)
if ok && rowFunc.FuncName.String() == ast.RowFunc {
continue
}
// Don't convert it if its charset is binary, so that we don't convert 0x12 to a string.
if er.ctxStack[i].GetType(er.sctx.GetEvalCtx()).GetCollate() == coll.Collation {
continue
}
tp := er.ctxStack[i].GetType(er.sctx.GetEvalCtx()).Clone()
if er.ctxStack[i].GetType(er.sctx.GetEvalCtx()).Hybrid() {
if !(expression.GetAccurateCmpType(er.sctx.GetEvalCtx(), er.ctxStack[stkLen-elemCnt-1], er.ctxStack[i]) == types.ETString) {
continue
}
tp = types.NewFieldType(mysql.TypeVarString)
} else if coll.Charset == charset.CharsetBin {
// When casting a character string to a binary string, if we still use a fixed-length representation,
// then 0 padding will be used, which can affect later execution.
// e.g. https://github.com/pingcap/tidb/pull/35053#pullrequestreview-1008757770 gives an unexpected case.
// On the other hand, we cannot directly return the original expr,
// since we need binary collation to do string comparison later.
// Here we use a VarString type for the cast, i.e. `cast(a as binary)`, to avoid this problem.
tp.SetType(mysql.TypeVarString)
}
tp.SetCharset(coll.Charset)
tp.SetCollate(coll.Collation)
er.ctxStack[i] = expression.BuildCastFunction(er.sctx, er.ctxStack[i], tp)
er.ctxStack[i].SetCoercibility(expression.CoercibilityExplicit)
}
}
}
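// caseToExpression rewrites a CASE expression into the `case` scalar function. A simple CASE such as
// `CASE x WHEN a THEN r1 WHEN b THEN r2 ELSE r3 END` is first normalized into the searched form
// `case(x = a, r1, x = b, r2, r3)`, while a searched CASE keeps its conditions as they are.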
func (er *expressionRewriter) caseToExpression(v *ast.CaseExpr) {
stkLen := len(er.ctxStack)
argsLen := 2 * len(v.WhenClauses)
if v.ElseClause != nil {
argsLen++
}
er.err = expression.CheckArgsNotMultiColumnRow(er.ctxStack[stkLen-argsLen:]...)
if er.err != nil {
return
}
// value -> ctxStack[stkLen-argsLen-1]
// when clause(condition, result) -> ctxStack[stkLen-argsLen:stkLen-1];
// else clause -> ctxStack[stkLen-1]
var args []expression.Expression
if v.Value != nil {
// args: eq scalar func(args: value, condition1), result1,
// eq scalar func(args: value, condition2), result2,
// ...
// else clause
value := er.ctxStack[stkLen-argsLen-1]
args = make([]expression.Expression, 0, argsLen)
for i := stkLen - argsLen; i < stkLen-1; i += 2 {
arg, err := er.newFunction(ast.EQ, types.NewFieldType(mysql.TypeTiny), value, er.ctxStack[i])
if err != nil {
er.err = err
return
}
args = append(args, arg)
args = append(args, er.ctxStack[i+1])
}
if v.ElseClause != nil {
args = append(args, er.ctxStack[stkLen-1])
}
argsLen++ // for trimming the value element later
} else {
// args: condition1, result1,
// condition2, result2,
// ...
// else clause
args = er.ctxStack[stkLen-argsLen:]
}
function, err := er.newFunction(ast.Case, &v.Type, args...)
if err != nil {
er.err = err
return
}
er.ctxStackPop(argsLen)
er.ctxStackAppend(function, types.EmptyName)
}
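// patternLikeOrIlikeToExpression rewrites LIKE/ILIKE. When the new collation framework is disabled, the
// pattern is a constant exact match (no wildcards) and the compared expression is string-typed, the
// predicate degenerates into `=` (or `<>` for NOT); otherwise a like/ilike scalar function is built with
// the escape character as the third argument.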
func (er *expressionRewriter) patternLikeOrIlikeToExpression(v *ast.PatternLikeOrIlikeExpr) {
l := len(er.ctxStack)
er.err = expression.CheckArgsNotMultiColumnRow(er.ctxStack[l-2:]...)
if er.err != nil {
return
}
char, col := er.sctx.GetCharsetInfo()
var function expression.Expression
fieldType := &types.FieldType{}
isPatternExactMatch := false
// Treat predicate 'like' or 'ilike' the same way as predicate '=' when it is an exact match and new collation is not enabled.
if patExpression, ok := er.ctxStack[l-1].(*expression.Constant); ok && !collate.NewCollationEnabled() {
patString, isNull, err := patExpression.EvalString(er.sctx.GetEvalCtx(), chunk.Row{})
if err != nil {
er.err = err
return
}
if !isNull {
patValue, patTypes := stringutil.CompilePattern(patString, v.Escape)
if stringutil.IsExactMatch(patTypes) && er.ctxStack[l-2].GetType(er.sctx.GetEvalCtx()).EvalType() == types.ETString {
op := ast.EQ
if v.Not {
op = ast.NE
}
types.DefaultTypeForValue(string(patValue), fieldType, char, col)
function, er.err = er.constructBinaryOpFunction(er.ctxStack[l-2],
&expression.Constant{Value: types.NewStringDatum(string(patValue)), RetType: fieldType},
op)
isPatternExactMatch = true
}
}
}
if !isPatternExactMatch {
funcName := ast.Like
if !v.IsLike {
funcName = ast.Ilike
}
types.DefaultTypeForValue(int(v.Escape), fieldType, char, col)
function = er.notToExpression(v.Not, funcName, &v.Type,
er.ctxStack[l-2], er.ctxStack[l-1], &expression.Constant{Value: types.NewIntDatum(int64(v.Escape)), RetType: fieldType})
}
er.ctxStackPop(2)
er.ctxStackAppend(function, types.EmptyName)
}
func (er *expressionRewriter) regexpToScalarFunc(v *ast.PatternRegexpExpr) {
l := len(er.ctxStack)
er.err = expression.CheckArgsNotMultiColumnRow(er.ctxStack[l-2:]...)
if er.err != nil {
return
}
function := er.notToExpression(v.Not, ast.Regexp, &v.Type, er.ctxStack[l-2], er.ctxStack[l-1])
er.ctxStackPop(2)
er.ctxStackAppend(function, types.EmptyName)
}
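// rowToScalarFunc rewrites a row constructor such as (a, b, c) into a single RowFunc scalar function
// over its already-rewritten elements.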
func (er *expressionRewriter) rowToScalarFunc(v *ast.RowExpr) {
stkLen := len(er.ctxStack)
length := len(v.Values)
rows := make([]expression.Expression, 0, length)
for i := stkLen - length; i < stkLen; i++ {
rows = append(rows, er.ctxStack[i])
}
er.ctxStackPop(length)
function, err := er.newFunction(ast.RowFunc, rows[0].GetType(er.sctx.GetEvalCtx()), rows...)
if err != nil {
er.err = err
return
}
er.ctxStackAppend(function, types.EmptyName)
}
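// wrapExpWithCast takes the three operands of a BETWEEN expression from the stack and wraps them with
// casts to the comparison type resolved by ResolveType4Between, so that both comparisons use the same
// evaluation type.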
func (er *expressionRewriter) wrapExpWithCast() (expr, lexp, rexp expression.Expression) {
stkLen := len(er.ctxStack)
expr, lexp, rexp = er.ctxStack[stkLen-3], er.ctxStack[stkLen-2], er.ctxStack[stkLen-1]
var castFunc func(expression.BuildContext, expression.Expression) expression.Expression
switch expression.ResolveType4Between(er.sctx.GetEvalCtx(), [3]expression.Expression{expr, lexp, rexp}) {
case types.ETInt:
expr = expression.WrapWithCastAsInt(er.sctx, expr, nil)
lexp = expression.WrapWithCastAsInt(er.sctx, lexp, nil)
rexp = expression.WrapWithCastAsInt(er.sctx, rexp, nil)
return
case types.ETReal:
castFunc = expression.WrapWithCastAsReal
case types.ETDecimal:
castFunc = expression.WrapWithCastAsDecimal
case types.ETString:
castFunc = func(ctx expression.BuildContext, e expression.Expression) expression.Expression {
// string-kind expressions do not need a cast
if e.GetType(er.sctx.GetEvalCtx()).EvalType().IsStringKind() {
return e
}
return expression.WrapWithCastAsString(ctx, e)
}
case types.ETDuration:
expr = expression.WrapWithCastAsTime(er.sctx, expr, types.NewFieldType(mysql.TypeDuration))
lexp = expression.WrapWithCastAsTime(er.sctx, lexp, types.NewFieldType(mysql.TypeDuration))
rexp = expression.WrapWithCastAsTime(er.sctx, rexp, types.NewFieldType(mysql.TypeDuration))
return
case types.ETDatetime:
expr = expression.WrapWithCastAsTime(er.sctx, expr, types.NewFieldType(mysql.TypeDatetime))
lexp = expression.WrapWithCastAsTime(er.sctx, lexp, types.NewFieldType(mysql.TypeDatetime))
rexp = expression.WrapWithCastAsTime(er.sctx, rexp, types.NewFieldType(mysql.TypeDatetime))
return
default:
return
}
expr = castFunc(er.sctx, expr)
lexp = castFunc(er.sctx, lexp)
rexp = castFunc(er.sctx, rexp)
return
}
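// betweenToExpression rewrites `expr BETWEEN lexp AND rexp` into `expr >= lexp AND expr <= rexp`
// (wrapped with a unary NOT for NOT BETWEEN), after unifying the comparison types and collations of the
// three operands.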
func (er *expressionRewriter) betweenToExpression(v *ast.BetweenExpr) {
stkLen := len(er.ctxStack)
er.err = expression.CheckArgsNotMultiColumnRow(er.ctxStack[stkLen-3:]...)
if er.err != nil {
return
}
expr, lexp, rexp := er.wrapExpWithCast()
coll, err := expression.CheckAndDeriveCollationFromExprs(er.sctx, "BETWEEN", types.ETInt, expr, lexp, rexp)
er.err = err
if er.err != nil {
return
}
// Handle enum or set. We need to know their real type to decide whether to cast them.
lt := expression.GetAccurateCmpType(er.sctx.GetEvalCtx(), expr, lexp)
rt := expression.GetAccurateCmpType(er.sctx.GetEvalCtx(), expr, rexp)
enumOrSetRealTypeIsStr := lt != types.ETInt && rt != types.ETInt
expr = expression.BuildCastCollationFunction(er.sctx, expr, coll, enumOrSetRealTypeIsStr)
lexp = expression.BuildCastCollationFunction(er.sctx, lexp, coll, enumOrSetRealTypeIsStr)
rexp = expression.BuildCastCollationFunction(er.sctx, rexp, coll, enumOrSetRealTypeIsStr)
var l, r expression.Expression
l, er.err = expression.NewFunction(er.sctx, ast.GE, &v.Type, expr, lexp)
if er.err != nil {
return
}
r, er.err = expression.NewFunction(er.sctx, ast.LE, &v.Type, expr, rexp)
if er.err != nil {
return
}
function, err := er.newFunction(ast.LogicAnd, &v.Type, l, r)
if err != nil {
er.err = err
return
}
if v.Not {
function, err = er.newFunction(ast.UnaryNot, &v.Type, function)
if err != nil {
er.err = err
return
}
}
er.ctxStackPop(3)
er.ctxStackAppend(function, types.EmptyName)
}
// rewriteFuncCall handles a FuncCallExpr and generates a customized function.
// It returns true if a rewrite is performed for the given FuncCallExpr, so that the original behavior is skipped.
// Otherwise it returns false to tell the caller that the original behavior is still needed.
func (er *expressionRewriter) rewriteFuncCall(v *ast.FuncCallExpr) bool {
switch v.FnName.L {
// When the column is NOT NULL, IFNULL on such a column can be optimized into a cast.
case ast.Ifnull:
if len(v.Args) != 2 {
er.err = expression.ErrIncorrectParameterCount.GenWithStackByArgs(v.FnName.O)
return true
}
stackLen := len(er.ctxStack)
lhs := er.ctxStack[stackLen-2]
rhs := er.ctxStack[stackLen-1]
col, isColumn := lhs.(*expression.Column)
var isEnumSet bool
if lhs.GetType(er.sctx.GetEvalCtx()).GetType() == mysql.TypeEnum || lhs.GetType(er.sctx.GetEvalCtx()).GetType() == mysql.TypeSet {
isEnumSet = true
}
// if expr1 is a column with not null flag, then we can optimize it as a cast.
if isColumn && !isEnumSet && mysql.HasNotNullFlag(col.RetType.GetFlag()) {
retTp, err := expression.InferType4ControlFuncs(er.sctx, ast.Ifnull, lhs, rhs)
if err != nil {
er.err = err
return true
}
retTp.AddFlag((lhs.GetType(er.sctx.GetEvalCtx()).GetFlag() & mysql.NotNullFlag) | (rhs.GetType(er.sctx.GetEvalCtx()).GetFlag() & mysql.NotNullFlag))
if lhs.GetType(er.sctx.GetEvalCtx()).GetType() == mysql.TypeNull && rhs.GetType(er.sctx.GetEvalCtx()).GetType() == mysql.TypeNull {
retTp.SetType(mysql.TypeNull)
retTp.SetFlen(0)
retTp.SetDecimal(0)
types.SetBinChsClnFlag(retTp)
}
er.ctxStackPop(len(v.Args))
casted := expression.BuildCastFunction(er.sctx, lhs, retTp)
er.ctxStackAppend(casted, types.EmptyName)
return true
}
return false
case ast.Nullif:
if len(v.Args) != 2 {
er.err = expression.ErrIncorrectParameterCount.GenWithStackByArgs(v.FnName.O)
return true
}
stackLen := len(er.ctxStack)
param1 := er.ctxStack[stackLen-2]
param2 := er.ctxStack[stackLen-1]
// param1 = param2
funcCompare, err := er.constructBinaryOpFunction(param1, param2, ast.EQ)
if err != nil {
er.err = err
return true
}
// NULL
nullTp := types.NewFieldType(mysql.TypeNull)
flen, decimal := mysql.GetDefaultFieldLengthAndDecimal(mysql.TypeNull)
nullTp.SetFlen(flen)
nullTp.SetDecimal(decimal)
paramNull := &expression.Constant{
Value: types.NewDatum(nil),
RetType: nullTp,
}
// if(param1 = param2, NULL, param1)
funcIf, err := er.newFunction(ast.If, &v.Type, funcCompare, paramNull, param1)
if err != nil {
er.err = err
return true
}
er.ctxStackPop(len(v.Args))
er.ctxStackAppend(funcIf, types.EmptyName)
return true
default:
return false
}
}
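// funcCallToExpressionWithPlanCtx rewrites function calls that need the planner context. Currently only
// GROUPING is handled: its arguments are resolved against the rollup Expand operator and the call is
// rebuilt over the gid column with grouping-set metadata attached.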
func (er *expressionRewriter) funcCallToExpressionWithPlanCtx(planCtx *exprRewriterPlanCtx, v *ast.FuncCallExpr) {
stackLen := len(er.ctxStack)
args := er.ctxStack[stackLen-len(v.Args):]
er.err = expression.CheckArgsNotMultiColumnRow(args...)
if er.err != nil {
return
}
var function expression.Expression
er.ctxStackPop(len(v.Args))
switch v.FnName.L {
case ast.Grouping:
// The grouping function should fetch the underlying grouping-sets meta and rewrite its args here.
// e.g. grouping(a) tries to find in which grouping sets the column 'a' is retained, collects those
// gids into a collection, and fills it into the grouping function meta. Besides, the first arg of the
// grouping function should be rewritten as the gid column defined/passed by Expand from the bottom up.
intest.AssertNotNil(planCtx)
if planCtx.rollExpand == nil {
er.err = plannererrors.ErrInvalidGroupFuncUse
er.ctxStackAppend(nil, types.EmptyName)
} else {
// When there are duplicate grouping sets, gpos is only used in shuffle keys and group keys,
// not in the grouping function.
// e.g. for rollup(a,a,b), the derived grouping sets are {a,a,b},{a,a,null},{a,null,null},{null,null,null}.
// For the second and third grouping sets, {a,a,null} and {a,null,null}, 'a' here is the column ref of the
// original column `a`. So at the static layer these two grouping sets are equivalent; we don't need to copy
// column `a` twice at the very beginning, and instead resort to gpos to distinguish them:
// {col-a, col-b, gid, gpos}
// {a, b, 0, 1}, {a, null, 1, 2}, {a, null, 1, 3}, {null, null, 2, 4}
// The grouping function still only needs to care about gid; gpos is what group and shuffle keys care about.
if len(args) > 64 {
er.err = plannererrors.ErrInvalidNumberOfArgs.GenWithStackByArgs("GROUPING", 64)
er.ctxStackAppend(nil, types.EmptyName)
return
}
// Resolve whether the grouping args appear in the GROUP BY items.
resolvedCols, err := planCtx.rollExpand.ResolveGroupingFuncArgsInGroupBy(args)
if err != nil {
er.err = err
er.ctxStackAppend(nil, types.EmptyName)
return
}
newArg := planCtx.rollExpand.GID.Clone()
init := func(groupingFunc *expression.ScalarFunction) (*expression.ScalarFunction, error) {
err = groupingFunc.Function.(*expression.BuiltinGroupingImplSig).SetMetadata(planCtx.rollExpand.GroupingMode, planCtx.rollExpand.GenerateGroupingMarks(resolvedCols))
return groupingFunc, err
}
function, er.err = er.newFunctionWithInit(v.FnName.L, &v.Type, init, newArg)
er.ctxStackAppend(function, types.EmptyName)
}
default:
er.err = errors.Errorf("invalid function: %s", v.FnName.L)
er.ctxStackAppend(nil, types.EmptyName)
}
}
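// funcCallToExpression rewrites an ordinary function call into a scalar function. When the plan cache
// is in use and the function is in the deferred-functions list, the call is wrapped into a constant with
// a DeferredExpr so that it is evaluated at execution time rather than constant-folded; unix_timestamp
// with arguments is the exception handled below.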
func (er *expressionRewriter) funcCallToExpression(v *ast.FuncCallExpr) {
stackLen := len(er.ctxStack)
args := er.ctxStack[stackLen-len(v.Args):]
er.err = expression.CheckArgsNotMultiColumnRow(args...)
if er.err != nil {
return
}
if er.rewriteFuncCall(v) {
return
}
var function expression.Expression
er.ctxStackPop(len(v.Args))
if ok := expression.IsDeferredFunctions(er.sctx, v.FnName.L); er.useCache() && ok {
// When the expression is unix_timestamp and the number of arguments is not zero,
// we deal with it as a normal expression.
if v.FnName.L == ast.UnixTimestamp && len(v.Args) != 0 {
function, er.err = er.newFunction(v.FnName.L, &v.Type, args...)
er.ctxStackAppend(function, types.EmptyName)
} else {
function, er.err = expression.NewFunctionBase(er.sctx, v.FnName.L, &v.Type, args...)
c := &expression.Constant{Value: types.NewDatum(nil), RetType: function.GetType(er.sctx.GetEvalCtx()).Clone(), DeferredExpr: function}
er.ctxStackAppend(c, types.EmptyName)
}
} else {
function, er.err = er.newFunction(v.FnName.L, &v.Type, args...)
er.ctxStackAppend(function, types.EmptyName)
}
}
// Currently, a TableName in an expression is only used by sequence functions like nextval(seq).
// The function arg should be evaluated as a table name rather than a normal column name, as MySQL does.
func (er *expressionRewriter) toTable(v *ast.TableName) {
fullName := v.Name.L
if len(v.Schema.L) != 0 {
fullName = v.Schema.L + "." + fullName
}
val := &expression.Constant{
Value: types.NewDatum(fullName),
RetType: types.NewFieldType(mysql.TypeString),
}
er.ctxStackAppend(val, types.EmptyName)
}
func (er *expressionRewriter) toParamMarker(v *driver.ParamMarkerExpr) {
var value *expression.Constant
value, er.err = expression.ParamMarkerExpression(er.sctx, v, false)
if er.err != nil {
return
}
initConstantRepertoire(er.sctx.GetEvalCtx(), value)
er.adjustUTF8MB4Collation(value.RetType)
if er.err != nil {
return
}
er.ctxStackAppend(value, types.EmptyName)
}
func (er *expressionRewriter) clause() clauseCode {
if er.planCtx != nil {
return er.planCtx.builder.curClause
}
return expressionClause
}
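// toColumn resolves a column name reference. It looks up, in order: the current schema, the source
// table (when rewriting without a planner context), the full schema of a natural/using join, and finally
// the outer query schemas, in which case a correlated column is produced.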
func (er *expressionRewriter) toColumn(v *ast.ColumnName) {
idx, err := expression.FindFieldName(er.names, v)
if err != nil {
er.err = plannererrors.ErrAmbiguous.GenWithStackByArgs(v.Name, clauseMsg[fieldList])
return
}
if idx >= 0 {
column := er.schema.Columns[idx]
if column.IsHidden {
er.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs(v.Name, clauseMsg[er.clause()])
return
}
er.ctxStackAppend(column, er.names[idx])
return
} else if er.planCtx == nil && er.sourceTable != nil &&
(v.Table.L == "" || er.sourceTable.Name.L == v.Table.L) {
colInfo := er.sourceTable.FindPublicColumnByName(v.Name.L)
if colInfo == nil || colInfo.Hidden {
er.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs(v.Name, clauseMsg[er.clause()])
return
}
er.ctxStackAppend(&expression.Column{
RetType: &colInfo.FieldType,
ID: colInfo.ID,
UniqueID: colInfo.ID,
OrigName: fmt.Sprintf("%s.%s", er.sourceTable.Name.L, colInfo.Name.L),
}, &types.FieldName{ColName: v.Name})
return
}
planCtx := er.planCtx
if planCtx == nil {
er.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs(v.String(), clauseMsg[er.clause()])
return
}
col, name, err := findFieldNameFromNaturalUsingJoin(planCtx.plan, v)
if err != nil {
er.err = err
return
} else if col != nil {
er.ctxStackAppend(col, name)
return
}
for i := len(planCtx.builder.outerSchemas) - 1; i >= 0; i-- {
outerSchema, outerName := planCtx.builder.outerSchemas[i], planCtx.builder.outerNames[i]
idx, err = expression.FindFieldName(outerName, v)
if idx >= 0 {
column := outerSchema.Columns[idx]
er.ctxStackAppend(&expression.CorrelatedColumn{Column: *column, Data: new(types.Datum)}, outerName[idx])
return
}
if err != nil {
er.err = plannererrors.ErrAmbiguous.GenWithStackByArgs(v.Name, clauseMsg[fieldList])
return
}
}
if _, ok := planCtx.plan.(*logicalop.LogicalUnionAll); ok && v.Table.O != "" {
er.err = plannererrors.ErrTablenameNotAllowedHere.GenWithStackByArgs(v.Table.O, "SELECT", clauseMsg[planCtx.builder.curClause])
return
}
if planCtx.builder.curClause == globalOrderByClause {
planCtx.builder.curClause = orderByClause
}
er.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs(v.String(), clauseMsg[planCtx.builder.curClause])
}
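// findFieldNameFromNaturalUsingJoin searches the FullSchema of a natural or using join for the given
// column name, looking through pass-through operators such as limit, selection, topN, sort and
// maxOneRow. It returns nils when the column is not found there.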
func findFieldNameFromNaturalUsingJoin(p base.LogicalPlan, v *ast.ColumnName) (col *expression.Column, name *types.FieldName, err error) {
switch x := p.(type) {
case *logicalop.LogicalLimit, *logicalop.LogicalSelection, *logicalop.LogicalTopN, *logicalop.LogicalSort, *logicalop.LogicalMaxOneRow:
return findFieldNameFromNaturalUsingJoin(p.Children()[0], v)
case *logicalop.LogicalJoin:
if x.FullSchema != nil {
idx, err := expression.FindFieldName(x.FullNames, v)
if err != nil {
return nil, nil, err
}
if idx >= 0 {
return x.FullSchema.Columns[idx], x.FullNames[idx], nil
}
}
}
return nil, nil, nil
}
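// evalDefaultExprForTable evaluates default(col) against an explicitly provided table when no planner
// context is available.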
func (er *expressionRewriter) evalDefaultExprForTable(v *ast.DefaultExpr, tbl *model.TableInfo) {
idx, err := expression.FindFieldName(er.names, v.Name)
if err != nil {
er.err = err
return
}
if idx < 0 {
er.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs(v.Name.OrigColName(), "field list")
return
}
name := er.names[idx]
er.evalFieldDefaultValue(name, tbl)
}
func (er *expressionRewriter) evalDefaultExprWithPlanCtx(planCtx *exprRewriterPlanCtx, v *ast.DefaultExpr) {
intest.AssertNotNil(planCtx)
var name *types.FieldName
// Here we will find the corresponding column for the default function. At the same time, we need to consider
// the issue of subqueries and name spaces.
// For example, we have two tables t1(a int default 1, b int) and t2(a int default -1, c int). Consider the following SQL:
// select a from t1 where a > (select default(a) from t2)
// Following the behavior of MySQL, we need to find column a in table t2. If table t2 does not have column a,
// then find it in table t1. If neither has it, return an error message.
// Based on the above description, we need to look in planCtx.builder.allNames from back to front.
for i := len(planCtx.builder.allNames) - 1; i >= 0; i-- {
idx, err := expression.FindFieldName(planCtx.builder.allNames[i], v.Name)
if err != nil {
er.err = err
return
}
if idx >= 0 {
name = planCtx.builder.allNames[i][idx]
break
}
}
if name == nil {
idx, err := expression.FindFieldName(er.names, v.Name)
if err != nil {
er.err = err
return
}
if idx < 0 {
er.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs(v.Name.OrigColName(), "field list")
return
}
name = er.names[idx]
}
dbName := name.DBName
if dbName.O == "" {
// if database name is not specified, use current database name
dbName = ast.NewCIStr(planCtx.builder.ctx.GetSessionVars().CurrentDB)
}
if name.OrigTblName.O == "" {
// The column is computed from some expression, for example:
// `select default(c) from (select (a+1) as c from t) as t0`
// In such a case, a 'no default' error is returned.
er.err = table.ErrNoDefaultValue.GenWithStackByArgs(name.ColName)
return
}
var tbl table.Table
tbl, er.err = planCtx.builder.is.TableByName(context.Background(), dbName, name.OrigTblName)
if er.err != nil {
return
}
er.evalFieldDefaultValue(name, tbl.Meta())
}
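// evalFieldDefaultValue pushes the default value of the resolved column onto the stack as a constant:
// a CURRENT_TIMESTAMP default on a datetime/timestamp column is evaluated to the current time, and other
// columns use their declared default value.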
func (er *expressionRewriter) evalFieldDefaultValue(field *types.FieldName, tblInfo *model.TableInfo) {
colName := field.OrigColName.L
if colName == "" {
// In some cases OrigColName is empty; use ColName instead.
colName = field.ColName.L
}
col := tblInfo.FindPublicColumnByName(colName)
if col == nil {
er.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs(colName, "field_list")
return
}
isCurrentTimestamp := hasCurrentDatetimeDefault(col)
var val *expression.Constant
switch {
case isCurrentTimestamp && (col.GetType() == mysql.TypeDatetime || col.GetType() == mysql.TypeTimestamp):
t, err := expression.GetTimeValue(er.sctx, ast.CurrentTimestamp, col.GetType(), col.GetDecimal(), nil)
if err != nil {
er.err = err
return
}
val = &expression.Constant{
Value: t,
RetType: types.NewFieldType(col.GetType()),
}
default:
// For other columns, just use the declared default value.
d, err := table.GetColDefaultValue(er.sctx, col)
if err != nil {
er.err = err
return
}
val = &expression.Constant{
Value: d,
RetType: col.FieldType.Clone(),
}
}
if er.err != nil {
return
}
er.ctxStackAppend(val, types.EmptyName)
}
// hasCurrentDatetimeDefault checks whether the column has a current_timestamp default value.
func hasCurrentDatetimeDefault(col *model.ColumnInfo) bool {
x, ok := col.DefaultValue.(string)
if !ok {
return false
}
return strings.ToLower(x) == ast.CurrentTimestamp
}
// hasLimit checks if the plan already contains a LIMIT operator
func hasLimit(plan base.LogicalPlan) bool {
if plan == nil {
return false
}
// Check if this is a LogicalLimit
if _, ok := plan.(*logicalop.LogicalLimit); ok {
return true
}
// Recursively check children
for _, child := range plan.Children() {
if hasLimit(child) {
return true
}
}
return false
}