2758 lines
102 KiB
Go
2758 lines
102 KiB
Go
// Copyright 2016 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package core
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/pingcap/errors"
|
|
"github.com/pingcap/tidb/pkg/expression"
|
|
"github.com/pingcap/tidb/pkg/expression/aggregation"
|
|
"github.com/pingcap/tidb/pkg/expression/exprctx"
|
|
"github.com/pingcap/tidb/pkg/expression/expropt"
|
|
"github.com/pingcap/tidb/pkg/infoschema"
|
|
"github.com/pingcap/tidb/pkg/meta/model"
|
|
"github.com/pingcap/tidb/pkg/parser/ast"
|
|
"github.com/pingcap/tidb/pkg/parser/charset"
|
|
"github.com/pingcap/tidb/pkg/parser/mysql"
|
|
"github.com/pingcap/tidb/pkg/parser/opcode"
|
|
"github.com/pingcap/tidb/pkg/planner/core/base"
|
|
"github.com/pingcap/tidb/pkg/planner/core/operator/logicalop"
|
|
"github.com/pingcap/tidb/pkg/planner/core/operator/physicalop"
|
|
"github.com/pingcap/tidb/pkg/planner/core/rule"
|
|
"github.com/pingcap/tidb/pkg/planner/util/coreusage"
|
|
"github.com/pingcap/tidb/pkg/sessionctx/vardef"
|
|
"github.com/pingcap/tidb/pkg/sessionctx/variable"
|
|
"github.com/pingcap/tidb/pkg/table"
|
|
"github.com/pingcap/tidb/pkg/types"
|
|
driver "github.com/pingcap/tidb/pkg/types/parser_driver"
|
|
"github.com/pingcap/tidb/pkg/util/chunk"
|
|
"github.com/pingcap/tidb/pkg/util/collate"
|
|
"github.com/pingcap/tidb/pkg/util/dbterror/plannererrors"
|
|
"github.com/pingcap/tidb/pkg/util/hint"
|
|
"github.com/pingcap/tidb/pkg/util/intest"
|
|
sem "github.com/pingcap/tidb/pkg/util/sem/compat"
|
|
"github.com/pingcap/tidb/pkg/util/stringutil"
|
|
)
|
|
|
|
// EvalSubqueryFirstRow evaluates an uncorrelated subquery once and returns its first row.
// It is a package-level function variable declared without an initializer, so it is nil
// until an implementation is assigned somewhere outside this declaration.
|
|
var EvalSubqueryFirstRow func(ctx context.Context, p base.PhysicalPlan, is infoschema.InfoSchema, sctx base.PlanContext) (row []types.Datum, err error)
|
|
|
|
// evalAstExprWithPlanCtx evaluates ast expression with plan context.
|
|
// Different with expression.EvalSimpleAst, it uses planner context and is more powerful to build some special expressions
|
|
// like subquery, window function, etc.
|
|
func evalAstExprWithPlanCtx(sctx base.PlanContext, expr ast.ExprNode) (types.Datum, error) {
|
|
if val, ok := expr.(*driver.ValueExpr); ok {
|
|
return val.Datum, nil
|
|
}
|
|
newExpr, err := rewriteAstExprWithPlanCtx(sctx, expr, nil, nil, false)
|
|
if err != nil {
|
|
return types.Datum{}, err
|
|
}
|
|
return newExpr.Eval(sctx.GetExprCtx().GetEvalCtx(), chunk.Row{})
|
|
}
|
|
|
|
// evalAstExpr evaluates ast expression directly.
|
|
func evalAstExpr(ctx expression.BuildContext, expr ast.ExprNode) (types.Datum, error) {
|
|
if val, ok := expr.(*driver.ValueExpr); ok {
|
|
return val.Datum, nil
|
|
}
|
|
newExpr, err := buildSimpleExpr(ctx, expr)
|
|
if err != nil {
|
|
return types.Datum{}, err
|
|
}
|
|
return newExpr.Eval(ctx.GetEvalCtx(), chunk.Row{})
|
|
}
|
|
|
|
// rewriteAstExprWithPlanCtx rewrites ast expression directly.
|
|
// Different with expression.BuildSimpleExpr, it uses planner context and is more powerful to build some special expressions
|
|
// like subquery, window function, etc.
|
|
func rewriteAstExprWithPlanCtx(sctx base.PlanContext, expr ast.ExprNode, schema *expression.Schema, names types.NameSlice, allowCastArray bool) (expression.Expression, error) {
|
|
var is infoschema.InfoSchema
|
|
// in tests, it may be null
|
|
if s, ok := sctx.GetInfoSchema().(infoschema.InfoSchema); ok {
|
|
is = s
|
|
}
|
|
b, savedBlockNames := NewPlanBuilder().Init(sctx, is, hint.NewQBHintHandler(nil))
|
|
b.allowBuildCastArray = allowCastArray
|
|
fakePlan := logicalop.LogicalTableDual{}.Init(sctx, 0)
|
|
if schema != nil {
|
|
fakePlan.SetSchema(schema)
|
|
fakePlan.SetOutputNames(names)
|
|
}
|
|
b.curClause = expressionClause
|
|
newExpr, _, err := b.rewrite(context.TODO(), expr, fakePlan, nil, true)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
sctx.GetSessionVars().PlannerSelectBlockAsName.Store(&savedBlockNames)
|
|
return newExpr, nil
|
|
}
|
|
|
|
func buildSimpleExpr(ctx expression.BuildContext, node ast.ExprNode, opts ...expression.BuildOption) (expression.Expression, error) {
|
|
intest.AssertNotNil(node)
|
|
if node == nil {
|
|
// This should never happen. Return error to make it easy to debug in case we have some unexpected bugs.
|
|
return nil, errors.New("expression node should be present")
|
|
}
|
|
|
|
var options expression.BuildOptions
|
|
for _, opt := range opts {
|
|
opt(&options)
|
|
}
|
|
|
|
if options.InputSchema == nil && len(options.InputNames) > 0 {
|
|
return nil, errors.New("InputSchema and InputNames should be specified at the same time")
|
|
}
|
|
|
|
if options.InputSchema != nil && len(options.InputSchema.Columns) != len(options.InputNames) {
|
|
return nil, errors.New("InputSchema and InputNames should be the same length")
|
|
}
|
|
|
|
// assert all input db names are the same if specified
|
|
intest.AssertFunc(func() bool {
|
|
if len(options.InputNames) == 0 {
|
|
return true
|
|
}
|
|
|
|
dbName := options.InputNames[0].DBName
|
|
if options.SourceTableDB.L != "" {
|
|
intest.Assert(dbName.L == options.SourceTableDB.L)
|
|
}
|
|
|
|
for _, name := range options.InputNames {
|
|
intest.Assert(name.DBName.L == dbName.L)
|
|
}
|
|
return true
|
|
})
|
|
|
|
rewriter := &expressionRewriter{
|
|
ctx: context.TODO(),
|
|
sctx: ctx,
|
|
schema: options.InputSchema,
|
|
names: options.InputNames,
|
|
sourceTable: options.SourceTable,
|
|
allowBuildCastArray: options.AllowCastArray,
|
|
asScalar: true,
|
|
}
|
|
|
|
if tbl := options.SourceTable; tbl != nil && rewriter.schema == nil {
|
|
cols, names, err := expression.ColumnInfos2ColumnsAndNames(ctx, options.SourceTableDB, tbl.Name, tbl.Cols(), tbl)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
intest.Assert(len(cols) == len(names))
|
|
rewriter.schema = expression.NewSchema(cols...)
|
|
rewriter.names = names
|
|
}
|
|
|
|
if rewriter.schema == nil {
|
|
rewriter.schema = expression.NewSchema()
|
|
}
|
|
|
|
expr, _, err := rewriteExprNode(rewriter, node, rewriter.asScalar)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if ft := options.TargetFieldType; ft != nil {
|
|
expr = expression.BuildCastFunction(ctx, expr, ft)
|
|
}
|
|
|
|
return expr, err
|
|
}
|
|
|
|
func (b *PlanBuilder) rewriteInsertOnDuplicateUpdate(ctx context.Context, exprNode ast.ExprNode, mockPlan base.LogicalPlan, insertPlan *physicalop.Insert) (expression.Expression, error) {
|
|
b.rewriterCounter++
|
|
defer func() { b.rewriterCounter-- }()
|
|
|
|
b.curClause = fieldList
|
|
rewriter := b.getExpressionRewriter(ctx, mockPlan)
|
|
// The rewriter maybe is obtained from "b.rewriterPool", "rewriter.err" is
|
|
// not nil means certain previous procedure has not handled this error.
|
|
// Here we give us one more chance to make a correct behavior by handling
|
|
// this missed error.
|
|
if rewriter.err != nil {
|
|
return nil, rewriter.err
|
|
}
|
|
|
|
rewriter.planCtx.insertPlan = insertPlan
|
|
rewriter.asScalar = true
|
|
rewriter.allowBuildCastArray = b.allowBuildCastArray
|
|
|
|
expr, _, err := rewriteExprNode(rewriter, exprNode, true)
|
|
return expr, err
|
|
}
|
|
|
|
// rewrite function rewrites ast expr to expression.Expression.
|
|
// aggMapper maps ast.AggregateFuncExpr to the columns offset in p's output schema.
|
|
// asScalar means whether this expression must be treated as a scalar expression.
|
|
// And this function returns a result expression, a new plan that may have apply or semi-join.
|
|
func (b *PlanBuilder) rewrite(ctx context.Context, exprNode ast.ExprNode, p base.LogicalPlan, aggMapper map[*ast.AggregateFuncExpr]int, asScalar bool) (expression.Expression, base.LogicalPlan, error) {
|
|
expr, resultPlan, err := b.rewriteWithPreprocess(ctx, exprNode, p, aggMapper, nil, asScalar, nil)
|
|
return expr, resultPlan, err
|
|
}
|
|
|
|
// rewriteWithPreprocess is for handling the situation that we need to adjust the input ast tree
|
|
// before really using its node in `expressionRewriter.Leave`. In that case, we first call
|
|
// er.preprocess(expr), which returns a new expr. Then we use the new expr in `Leave`.
|
|
func (b *PlanBuilder) rewriteWithPreprocess(
|
|
ctx context.Context,
|
|
exprNode ast.ExprNode,
|
|
p base.LogicalPlan, aggMapper map[*ast.AggregateFuncExpr]int,
|
|
windowMapper map[*ast.WindowFuncExpr]int,
|
|
asScalar bool,
|
|
preprocess func(ast.Node) ast.Node,
|
|
) (expression.Expression, base.LogicalPlan, error) {
|
|
b.rewriterCounter++
|
|
defer func() { b.rewriterCounter-- }()
|
|
|
|
rewriter := b.getExpressionRewriter(ctx, p)
|
|
// The rewriter maybe is obtained from "b.rewriterPool", "rewriter.err" is
|
|
// not nil means certain previous procedure has not handled this error.
|
|
// Here we give us one more chance to make a correct behavior by handling
|
|
// this missed error.
|
|
if rewriter.err != nil {
|
|
return nil, nil, rewriter.err
|
|
}
|
|
|
|
rewriter.planCtx.aggrMap = aggMapper
|
|
rewriter.planCtx.windowMap = windowMapper
|
|
rewriter.asScalar = asScalar
|
|
rewriter.allowBuildCastArray = b.allowBuildCastArray
|
|
rewriter.preprocess = preprocess
|
|
|
|
expr, resultPlan, err := rewriteExprNode(rewriter, exprNode, asScalar)
|
|
return expr, resultPlan, err
|
|
}
|
|
|
|
func (b *PlanBuilder) getExpressionRewriter(ctx context.Context, p base.LogicalPlan) (rewriter *expressionRewriter) {
|
|
defer func() {
|
|
if p != nil {
|
|
if join, ok := p.(*logicalop.LogicalJoin); ok && join.FullSchema != nil {
|
|
rewriter.schema = join.FullSchema
|
|
rewriter.names = join.FullNames
|
|
} else {
|
|
rewriter.schema = p.Schema()
|
|
rewriter.names = p.OutputNames()
|
|
}
|
|
}
|
|
}()
|
|
|
|
if len(b.rewriterPool) < b.rewriterCounter {
|
|
rewriter = &expressionRewriter{
|
|
sctx: b.ctx.GetExprCtx(), ctx: ctx,
|
|
planCtx: &exprRewriterPlanCtx{plan: p, builder: b, curClause: b.curClause, rollExpand: b.currentBlockExpand},
|
|
}
|
|
b.rewriterPool = append(b.rewriterPool, rewriter)
|
|
return
|
|
}
|
|
|
|
rewriter = b.rewriterPool[b.rewriterCounter-1]
|
|
rewriter.asScalar = false
|
|
rewriter.preprocess = nil
|
|
rewriter.disableFoldCounter = 0
|
|
rewriter.tryFoldCounter = 0
|
|
rewriter.ctxStack = rewriter.ctxStack[:0]
|
|
rewriter.ctxNameStk = rewriter.ctxNameStk[:0]
|
|
rewriter.ctx = ctx
|
|
rewriter.err = nil
|
|
rewriter.planCtx.plan = p
|
|
rewriter.planCtx.curClause = b.curClause
|
|
rewriter.planCtx.aggrMap = nil
|
|
rewriter.planCtx.insertPlan = nil
|
|
rewriter.planCtx.rollExpand = b.currentBlockExpand
|
|
return
|
|
}
|
|
|
|
func rewriteExprNode(rewriter *expressionRewriter, exprNode ast.ExprNode, asScalar bool) (expression.Expression, base.LogicalPlan, error) {
|
|
planCtx := rewriter.planCtx
|
|
// sourceTable is only used to build simple expression with one table
|
|
// when planCtx is present, sourceTable should be nil.
|
|
intest.Assert(planCtx == nil || rewriter.sourceTable == nil)
|
|
if planCtx != nil && planCtx.plan != nil {
|
|
curColLen := planCtx.plan.Schema().Len()
|
|
defer func() {
|
|
names := planCtx.plan.OutputNames().Shallow()[:curColLen]
|
|
for i := curColLen; i < planCtx.plan.Schema().Len(); i++ {
|
|
names = append(names, types.EmptyName)
|
|
}
|
|
// After rewriting finished, only old columns are visible.
|
|
// e.g. select * from t where t.a in (select t1.a from t1);
|
|
// The output columns before we enter the subquery are the columns from t.
|
|
// But when we leave the subquery `t.a in (select t1.a from t1)`, we got a Apply operator
|
|
// and the output columns become [t.*, t1.*]. But t1.* is used only inside the subquery. If there's another filter
|
|
// which is also a subquery where t1 is involved. The name resolving will fail if we still expose the column from
|
|
// the previous subquery.
|
|
// So here we just reset the names to empty to avoid this situation.
|
|
// TODO: implement ScalarSubQuery and resolve it during optimizing. In building phase, we will not change the plan's structure.
|
|
planCtx.plan.SetOutputNames(names)
|
|
}()
|
|
}
|
|
exprNode.Accept(rewriter)
|
|
if rewriter.err != nil {
|
|
return nil, nil, errors.Trace(rewriter.err)
|
|
}
|
|
|
|
var plan base.LogicalPlan
|
|
if planCtx != nil {
|
|
plan = planCtx.plan
|
|
}
|
|
|
|
if !asScalar && len(rewriter.ctxStack) == 0 {
|
|
return nil, plan, nil
|
|
}
|
|
if len(rewriter.ctxStack) != 1 {
|
|
return nil, nil, errors.Errorf("context len %v is invalid", len(rewriter.ctxStack))
|
|
}
|
|
rewriter.err = expression.CheckArgsNotMultiColumnRow(rewriter.ctxStack[0])
|
|
if rewriter.err != nil {
|
|
return nil, nil, errors.Trace(rewriter.err)
|
|
}
|
|
return rewriter.ctxStack[0], plan, nil
|
|
}
|
|
|
|
type exprRewriterPlanCtx struct {
|
|
plan base.LogicalPlan
|
|
builder *PlanBuilder
|
|
|
|
// curClause tracks which part of the query is being processed
|
|
curClause clauseCode
|
|
|
|
aggrMap map[*ast.AggregateFuncExpr]int
|
|
windowMap map[*ast.WindowFuncExpr]int
|
|
|
|
// insertPlan is only used to rewrite the expressions inside the assignment
|
|
// of the "INSERT" statement.
|
|
insertPlan *physicalop.Insert
|
|
|
|
rollExpand *logicalop.LogicalExpand
|
|
}
|
|
|
|
type expressionRewriter struct {
|
|
ctxStack []expression.Expression
|
|
ctxNameStk []*types.FieldName
|
|
schema *expression.Schema
|
|
names []*types.FieldName
|
|
err error
|
|
|
|
sctx expression.BuildContext
|
|
ctx context.Context
|
|
|
|
// asScalar indicates the return value must be a scalar value.
|
|
// NOTE: This value can be changed during expression rewritten.
|
|
asScalar bool
|
|
// allowBuildCastArray indicates whether allow cast(... as ... array).
|
|
allowBuildCastArray bool
|
|
// sourceTable is only used to build simple expression without all columns from a single table
|
|
sourceTable *model.TableInfo
|
|
|
|
// preprocess is called for every ast.Node in Leave.
|
|
preprocess func(ast.Node) ast.Node
|
|
|
|
// disableFoldCounter controls fold-disabled scope. If > 0, rewriter will NOT do constant folding.
|
|
// Typically, during visiting AST, while entering the scope(disable), the counter will +1; while
|
|
// leaving the scope(enable again), the counter will -1.
|
|
// NOTE: This value can be changed during expression rewritten.
|
|
disableFoldCounter int
|
|
tryFoldCounter int
|
|
|
|
planCtx *exprRewriterPlanCtx
|
|
}
|
|
|
|
func (er *expressionRewriter) ctxStackLen() int {
|
|
return len(er.ctxStack)
|
|
}
|
|
|
|
func (er *expressionRewriter) ctxStackPop(num int) {
|
|
l := er.ctxStackLen()
|
|
er.ctxStack = er.ctxStack[:l-num]
|
|
er.ctxNameStk = er.ctxNameStk[:l-num]
|
|
}
|
|
|
|
func (er *expressionRewriter) ctxStackAppend(col expression.Expression, name *types.FieldName) {
|
|
er.ctxStack = append(er.ctxStack, col)
|
|
er.ctxNameStk = append(er.ctxNameStk, name)
|
|
}
|
|
|
|
// constructBinaryOpFunction converts binary operator functions
|
|
/*
|
|
The algorithm is as follows:
|
|
1. If the length of the two sides of the expression is 1, return l op r directly.
|
|
2. If the length of the two sides of the expression is not equal, return an error.
|
|
3. If the operator is EQ, NE, or NullEQ, converts (a0,a1,a2) op (b0,b1,b2) to (a0 op b0) and (a1 op b1) and (a2 op b2)
|
|
4. If the operator is not EQ, NE, or NullEQ,
|
|
converts (a0,a1,a2) op (b0,b1,b2) to (a0 > b0) or (a0 = b0 and a1 > b1) or (a0 = b0 and a1 = b1 and a2 op b2)
|
|
Especially, op is GE or LE, the prefix element will be converted to > or <.
|
|
converts (a0,a1,a2) >= (b0,b1,b2) to (a0 > b0) or (a0 = b0 and a1 > b1) or (a0 = b0 and a1 = b1 and a2 >= b2)
|
|
The only different between >= and > is that >= additional include the (x,y,z) = (a,b,c).
|
|
*/
|
|
func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression, r expression.Expression, op string) (expression.Expression, error) {
|
|
lLen, rLen := expression.GetRowLen(l), expression.GetRowLen(r)
|
|
if lLen == 1 && rLen == 1 {
|
|
return er.newFunction(op, types.NewFieldType(mysql.TypeTiny), l, r)
|
|
} else if rLen != lLen {
|
|
return nil, expression.ErrOperandColumns.GenWithStackByArgs(lLen)
|
|
}
|
|
switch op {
|
|
case ast.EQ, ast.NE, ast.NullEQ:
|
|
funcs := make([]expression.Expression, lLen)
|
|
for i := range lLen {
|
|
var err error
|
|
funcs[i], err = er.constructBinaryOpFunction(expression.GetFuncArg(l, i), expression.GetFuncArg(r, i), op)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
if op == ast.NE {
|
|
return expression.ComposeDNFCondition(er.sctx, funcs...), nil
|
|
}
|
|
return expression.ComposeCNFCondition(er.sctx, funcs...), nil
|
|
default:
|
|
/*
|
|
The algorithm is as follows:
|
|
1. Iterate over i left columns and construct his own CNF for each left column.
|
|
1.1 Iterate over j (every i-1 columns) to l[j]=r[j]
|
|
1.2 Build current i column with op to l[i] op r[i]
|
|
1.3 Combine 1.1 and 1.2 predicates with AND operator
|
|
2. Combine every i CNF with OR operator.
|
|
*/
|
|
resultDNFList := make([]expression.Expression, 0, lLen)
|
|
// Step 1
|
|
for i := range lLen {
|
|
exprList := make([]expression.Expression, 0, i+1)
|
|
// Step 1.1 build prefix equal conditions
|
|
// (l[0], ... , l[i-1], ...) op (r[0], ... , r[i-1], ...) should be convert to
|
|
// l[0] = r[0] and l[1] = r[1] and ... and l[i-1] = r[i-1]
|
|
for j := range i {
|
|
jExpr, err := er.constructBinaryOpFunction(expression.GetFuncArg(l, j), expression.GetFuncArg(r, j), ast.EQ)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
exprList = append(exprList, jExpr)
|
|
}
|
|
|
|
// Especially, op is GE or LE, the prefix element will be converted to > or <.
|
|
degeneratedOp := op
|
|
if i < lLen-1 {
|
|
switch op {
|
|
case ast.GE:
|
|
degeneratedOp = ast.GT
|
|
case ast.LE:
|
|
degeneratedOp = ast.LT
|
|
}
|
|
}
|
|
// Step 1.2
|
|
currentIndexExpr, err := er.constructBinaryOpFunction(expression.GetFuncArg(l, i), expression.GetFuncArg(r, i), degeneratedOp)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
exprList = append(exprList, currentIndexExpr)
|
|
// Step 1.3
|
|
currentExpr := expression.ComposeCNFCondition(er.sctx, exprList...)
|
|
resultDNFList = append(resultDNFList, currentExpr)
|
|
}
|
|
// Step 2
|
|
return expression.ComposeDNFCondition(er.sctx, resultDNFList...), nil
|
|
}
|
|
}
|
|
|
|
// buildSubquery translates the subquery ast to plan.
|
|
// Subquery related hints are returned through hintFlags. Please see comments around HintFlagSemiJoinRewrite and PlanBuilder.subQueryHintFlags for details.
|
|
func (er *expressionRewriter) buildSubquery(ctx context.Context, planCtx *exprRewriterPlanCtx, subq *ast.SubqueryExpr, subqueryCtx subQueryCtx) (np base.LogicalPlan, hintFlags uint64, err error) {
|
|
intest.AssertNotNil(planCtx)
|
|
b := planCtx.builder
|
|
if er.schema != nil {
|
|
outerSchema := er.schema.Clone()
|
|
b.outerSchemas = append(b.outerSchemas, outerSchema)
|
|
b.outerNames = append(b.outerNames, er.names)
|
|
b.outerBlockExpand = append(b.outerBlockExpand, b.currentBlockExpand)
|
|
// set it to nil, otherwise, inner qb will use outer expand meta to rewrite expressions.
|
|
b.currentBlockExpand = nil
|
|
defer func() {
|
|
b.outerSchemas = b.outerSchemas[0 : len(b.outerSchemas)-1]
|
|
b.outerNames = b.outerNames[0 : len(b.outerNames)-1]
|
|
b.currentBlockExpand = b.outerBlockExpand[len(b.outerBlockExpand)-1]
|
|
b.outerBlockExpand = b.outerBlockExpand[0 : len(b.outerBlockExpand)-1]
|
|
}()
|
|
}
|
|
// Store the old value before we enter the subquery and reset they to default value.
|
|
oldSubQCtx := b.subQueryCtx
|
|
b.subQueryCtx = subqueryCtx
|
|
oldHintFlags := b.subQueryHintFlags
|
|
b.subQueryHintFlags = 0
|
|
outerWindowSpecs := b.windowSpecs
|
|
defer func() {
|
|
b.windowSpecs = outerWindowSpecs
|
|
b.subQueryCtx = oldSubQCtx
|
|
b.subQueryHintFlags = oldHintFlags
|
|
}()
|
|
|
|
np, err = b.buildResultSetNode(ctx, subq.Query, false)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
hintFlags = b.subQueryHintFlags
|
|
// Pop the handle map generated by the subquery.
|
|
b.handleHelper.popMap()
|
|
return np, hintFlags, nil
|
|
}
|
|
|
|
func (er *expressionRewriter) requirePlanCtx(inNode ast.Node, detail string) (ctx *exprRewriterPlanCtx, err error) {
|
|
if ctx = er.planCtx; ctx == nil {
|
|
if detail != "" {
|
|
detail = ", " + detail
|
|
}
|
|
err = errors.Errorf("planCtx is required when rewriting node: '%T'%s", inNode, detail)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Enter implements Visitor interface.
|
|
func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
|
|
enterWithPlanCtx := func(fn func(*exprRewriterPlanCtx) (ast.Node, bool)) (ast.Node, bool) {
|
|
planCtx, err := er.requirePlanCtx(inNode, "")
|
|
if err != nil {
|
|
er.err = err
|
|
return inNode, true
|
|
}
|
|
return fn(planCtx)
|
|
}
|
|
|
|
switch v := inNode.(type) {
|
|
case *ast.AggregateFuncExpr:
|
|
return enterWithPlanCtx(func(planCtx *exprRewriterPlanCtx) (ast.Node, bool) {
|
|
index, ok := -1, false
|
|
if planCtx.aggrMap != nil {
|
|
index, ok = planCtx.aggrMap[v]
|
|
}
|
|
if ok {
|
|
// index < 0 indicates this is a correlated aggregate belonging to outer query,
|
|
// for which a correlated column will be created later, so we append a null constant
|
|
// as a temporary result expression.
|
|
if index < 0 {
|
|
er.ctxStackAppend(expression.NewNull(), types.EmptyName)
|
|
} else {
|
|
// index >= 0 indicates this is a regular aggregate column
|
|
er.ctxStackAppend(er.schema.Columns[index], er.names[index])
|
|
}
|
|
return inNode, true
|
|
}
|
|
// replace correlated aggregate in sub-query with its corresponding correlated column
|
|
if col, ok := planCtx.builder.correlatedAggMapper[v]; ok {
|
|
er.ctxStackAppend(col, types.EmptyName)
|
|
return inNode, true
|
|
}
|
|
er.err = plannererrors.ErrInvalidGroupFuncUse
|
|
return inNode, true
|
|
})
|
|
case *ast.ColumnNameExpr:
|
|
if planCtx := er.planCtx; planCtx != nil {
|
|
if index, ok := planCtx.builder.colMapper[v]; ok {
|
|
er.ctxStackAppend(er.schema.Columns[index], er.names[index])
|
|
return inNode, true
|
|
}
|
|
}
|
|
case *ast.CompareSubqueryExpr:
|
|
return enterWithPlanCtx(func(planCtx *exprRewriterPlanCtx) (ast.Node, bool) {
|
|
return er.handleCompareSubquery(er.ctx, planCtx, v)
|
|
})
|
|
case *ast.ExistsSubqueryExpr:
|
|
return enterWithPlanCtx(func(planCtx *exprRewriterPlanCtx) (ast.Node, bool) {
|
|
return er.handleExistSubquery(er.ctx, planCtx, v)
|
|
})
|
|
case *ast.PatternInExpr:
|
|
if v.Sel != nil {
|
|
return enterWithPlanCtx(func(planCtx *exprRewriterPlanCtx) (ast.Node, bool) {
|
|
return er.handleInSubquery(er.ctx, planCtx, v)
|
|
})
|
|
}
|
|
if len(v.List) != 1 {
|
|
break
|
|
}
|
|
// For 10 in ((select * from t)), the parser won't set v.Sel.
|
|
// So we must process this case here.
|
|
x := v.List[0]
|
|
for {
|
|
switch y := x.(type) {
|
|
case *ast.SubqueryExpr:
|
|
v.Sel = y
|
|
return enterWithPlanCtx(func(planCtx *exprRewriterPlanCtx) (ast.Node, bool) {
|
|
return er.handleInSubquery(er.ctx, planCtx, v)
|
|
})
|
|
case *ast.ParenthesesExpr:
|
|
x = y.Expr
|
|
default:
|
|
// Expect its left and right child to be a scalar value.
|
|
er.asScalar = true
|
|
return inNode, false
|
|
}
|
|
}
|
|
case *ast.SubqueryExpr:
|
|
return enterWithPlanCtx(func(planCtx *exprRewriterPlanCtx) (ast.Node, bool) {
|
|
return er.handleScalarSubquery(er.ctx, planCtx, v)
|
|
})
|
|
case *ast.ParenthesesExpr:
|
|
case *ast.ValuesExpr:
|
|
return enterWithPlanCtx(func(planCtx *exprRewriterPlanCtx) (ast.Node, bool) {
|
|
schema, names := er.schema, er.names
|
|
// NOTE: "er.insertPlan != nil" means that we are rewriting the
|
|
// expressions inside the assignment of "INSERT" statement. we have to
|
|
// use the "tableSchema" of that "insertPlan".
|
|
if planCtx.insertPlan != nil {
|
|
schema = planCtx.insertPlan.TableSchema
|
|
names = planCtx.insertPlan.TableColNames
|
|
}
|
|
idx, err := expression.FindFieldName(names, v.Column.Name)
|
|
if err != nil {
|
|
er.err = err
|
|
return inNode, false
|
|
}
|
|
if idx < 0 {
|
|
er.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs(v.Column.Name.OrigColName(), "field list")
|
|
return inNode, false
|
|
}
|
|
col := schema.Columns[idx]
|
|
er.ctxStackAppend(expression.NewValuesFunc(er.sctx, col.Index, col.RetType), types.EmptyName)
|
|
return inNode, true
|
|
})
|
|
case *ast.WindowFuncExpr:
|
|
return enterWithPlanCtx(func(planCtx *exprRewriterPlanCtx) (ast.Node, bool) {
|
|
intest.AssertNotNil(planCtx)
|
|
index, ok := -1, false
|
|
if planCtx.windowMap != nil {
|
|
index, ok = planCtx.windowMap[v]
|
|
}
|
|
if !ok {
|
|
er.err = plannererrors.ErrWindowInvalidWindowFuncUse.GenWithStackByArgs(strings.ToLower(v.Name))
|
|
return inNode, true
|
|
}
|
|
er.ctxStackAppend(er.schema.Columns[index], er.names[index])
|
|
return inNode, true
|
|
})
|
|
case *ast.FuncCallExpr:
|
|
er.asScalar = true
|
|
if _, ok := expression.DisableFoldFunctions[v.FnName.L]; ok {
|
|
er.disableFoldCounter++
|
|
}
|
|
if _, ok := expression.TryFoldFunctions[v.FnName.L]; ok {
|
|
er.tryFoldCounter++
|
|
}
|
|
case *ast.CaseExpr:
|
|
er.asScalar = true
|
|
if _, ok := expression.DisableFoldFunctions["case"]; ok {
|
|
er.disableFoldCounter++
|
|
}
|
|
if _, ok := expression.TryFoldFunctions["case"]; ok {
|
|
er.tryFoldCounter++
|
|
}
|
|
case *ast.BinaryOperationExpr:
|
|
er.asScalar = true
|
|
if v.Op == opcode.LogicAnd || v.Op == opcode.LogicOr {
|
|
er.tryFoldCounter++
|
|
}
|
|
case *ast.SetCollationExpr:
|
|
// Do nothing
|
|
default:
|
|
er.asScalar = true
|
|
}
|
|
return inNode, false
|
|
}
|
|
|
|
func (er *expressionRewriter) buildSemiApplyFromEqualSubq(np base.LogicalPlan, planCtx *exprRewriterPlanCtx, l, r expression.Expression, not, markNoDecorrelate bool) {
|
|
intest.AssertNotNil(planCtx)
|
|
if er.asScalar || not {
|
|
if expression.GetRowLen(r) == 1 {
|
|
rCol := r.(*expression.Column)
|
|
// If both input columns of `!= all / = any` expression are not null, we can treat the expression
|
|
// as normal column equal condition.
|
|
if !expression.ExprNotNull(er.sctx.GetEvalCtx(), l) || !expression.ExprNotNull(er.sctx.GetEvalCtx(), rCol) {
|
|
rColCopy := *rCol
|
|
rColCopy.InOperand = true
|
|
r = &rColCopy
|
|
l = expression.SetExprColumnInOperand(l)
|
|
}
|
|
} else {
|
|
rowFunc := r.(*expression.ScalarFunction)
|
|
rargs := rowFunc.GetArgs()
|
|
args := make([]expression.Expression, 0, len(rargs))
|
|
modified := false
|
|
for i, rarg := range rargs {
|
|
larg := expression.GetFuncArg(l, i)
|
|
if !expression.ExprNotNull(er.sctx.GetEvalCtx(), larg) || !expression.ExprNotNull(er.sctx.GetEvalCtx(), rarg) {
|
|
rCol := rarg.(*expression.Column)
|
|
rColCopy := *rCol
|
|
rColCopy.InOperand = true
|
|
rarg = &rColCopy
|
|
modified = true
|
|
}
|
|
args = append(args, rarg)
|
|
}
|
|
if modified {
|
|
r, er.err = er.newFunction(ast.RowFunc, args[0].GetType(er.sctx.GetEvalCtx()), args...)
|
|
if er.err != nil {
|
|
return
|
|
}
|
|
l = expression.SetExprColumnInOperand(l)
|
|
}
|
|
}
|
|
}
|
|
var condition expression.Expression
|
|
condition, er.err = er.constructBinaryOpFunction(l, r, ast.EQ)
|
|
if er.err != nil {
|
|
return
|
|
}
|
|
planCtx.plan, er.err = planCtx.builder.buildSemiApply(planCtx.plan, np, []expression.Expression{condition}, er.asScalar, not, false, markNoDecorrelate)
|
|
}
|
|
|
|
func (er *expressionRewriter) handleCompareSubquery(ctx context.Context, planCtx *exprRewriterPlanCtx, v *ast.CompareSubqueryExpr) (ast.Node, bool) {
|
|
intest.AssertNotNil(planCtx)
|
|
b := planCtx.builder
|
|
ci := b.prepareCTECheckForSubQuery()
|
|
defer resetCTECheckForSubQuery(ci)
|
|
v.L.Accept(er)
|
|
if er.err != nil {
|
|
return v, true
|
|
}
|
|
lexpr := er.ctxStack[len(er.ctxStack)-1]
|
|
subq, ok := v.R.(*ast.SubqueryExpr)
|
|
if !ok {
|
|
er.err = errors.Errorf("Unknown compare type %T", v.R)
|
|
return v, true
|
|
}
|
|
np, hintFlags, err := er.buildSubquery(ctx, planCtx, subq, handlingCompareSubquery)
|
|
if err != nil {
|
|
er.err = err
|
|
return v, true
|
|
}
|
|
corCols := coreusage.ExtractCorColumnsBySchema4LogicalPlan(np, planCtx.plan.Schema())
|
|
noDecorrelate := isNoDecorrelate(planCtx, corCols, hintFlags, handlingCompareSubquery)
|
|
|
|
// Only (a,b,c) = any (...) and (a,b,c) != all (...) can use row expression.
|
|
canMultiCol := (!v.All && v.Op == opcode.EQ) || (v.All && v.Op == opcode.NE)
|
|
if !canMultiCol && (expression.GetRowLen(lexpr) != 1 || np.Schema().Len() != 1) {
|
|
er.err = expression.ErrOperandColumns.GenWithStackByArgs(1)
|
|
return v, true
|
|
}
|
|
lLen := expression.GetRowLen(lexpr)
|
|
if lLen != np.Schema().Len() {
|
|
er.err = expression.ErrOperandColumns.GenWithStackByArgs(lLen)
|
|
return v, true
|
|
}
|
|
var rexpr expression.Expression
|
|
if np.Schema().Len() == 1 {
|
|
rexpr = np.Schema().Columns[0]
|
|
} else {
|
|
args := make([]expression.Expression, 0, np.Schema().Len())
|
|
for _, col := range np.Schema().Columns {
|
|
args = append(args, col)
|
|
}
|
|
rexpr, er.err = er.newFunction(ast.RowFunc, args[0].GetType(er.sctx.GetEvalCtx()), args...)
|
|
if er.err != nil {
|
|
return v, true
|
|
}
|
|
}
|
|
|
|
// Lexpr cannot compare with rexpr by different collate
|
|
opString := new(strings.Builder)
|
|
v.Op.Format(opString)
|
|
_, er.err = expression.CheckAndDeriveCollationFromExprs(er.sctx, opString.String(), types.ETInt, lexpr, rexpr)
|
|
if er.err != nil {
|
|
return v, true
|
|
}
|
|
|
|
switch v.Op {
|
|
// Only EQ, NE and NullEQ can be composed with and.
|
|
case opcode.EQ, opcode.NE, opcode.NullEQ:
|
|
if v.Op == opcode.EQ {
|
|
if v.All {
|
|
er.handleEQAll(planCtx, lexpr, rexpr, np, noDecorrelate)
|
|
} else {
|
|
// `a = any(subq)` will be rewriten as `a in (subq)`.
|
|
er.asScalar = true
|
|
er.buildSemiApplyFromEqualSubq(np, planCtx, lexpr, rexpr, false, noDecorrelate)
|
|
if er.err != nil {
|
|
return v, true
|
|
}
|
|
}
|
|
} else if v.Op == opcode.NE {
|
|
if v.All {
|
|
// `a != all(subq)` will be rewriten as `a not in (subq)`.
|
|
er.asScalar = true
|
|
er.buildSemiApplyFromEqualSubq(np, planCtx, lexpr, rexpr, true, noDecorrelate)
|
|
if er.err != nil {
|
|
return v, true
|
|
}
|
|
} else {
|
|
er.handleNEAny(planCtx, lexpr, rexpr, np, noDecorrelate)
|
|
}
|
|
} else {
|
|
// TODO: Support this in future.
|
|
er.err = errors.New("We don't support <=> all or <=> any now")
|
|
return v, true
|
|
}
|
|
default:
|
|
// When < all or > any , the agg function should use min.
|
|
useMin := ((v.Op == opcode.LT || v.Op == opcode.LE) && v.All) || ((v.Op == opcode.GT || v.Op == opcode.GE) && !v.All)
|
|
er.handleOtherComparableSubq(planCtx, lexpr, rexpr, np, useMin, v.Op.String(), v.All, noDecorrelate)
|
|
}
|
|
if er.asScalar {
|
|
// The parent expression only use the last column in schema, which represents whether the condition is matched.
|
|
er.ctxStack[len(er.ctxStack)-1] = planCtx.plan.Schema().Columns[planCtx.plan.Schema().Len()-1]
|
|
er.ctxNameStk[len(er.ctxNameStk)-1] = planCtx.plan.OutputNames()[planCtx.plan.Schema().Len()-1]
|
|
}
|
|
return v, true
|
|
}
|
|
|
|
// handleOtherComparableSubq handles the queries like < any, < max, etc. For example, if the query is t.id < any (select s.id from s),
|
|
// it will be rewrote to t.id < (select max(s.id) from s).
|
|
func (er *expressionRewriter) handleOtherComparableSubq(planCtx *exprRewriterPlanCtx, lexpr, rexpr expression.Expression, np base.LogicalPlan, useMin bool, cmpFunc string, all, markNoDecorrelate bool) {
|
|
intest.AssertNotNil(planCtx)
|
|
plan4Agg := logicalop.LogicalAggregation{}.Init(planCtx.builder.ctx, planCtx.builder.getSelectOffset())
|
|
if hintinfo := planCtx.builder.TableHints(); hintinfo != nil {
|
|
plan4Agg.PreferAggType = hintinfo.PreferAggType
|
|
plan4Agg.PreferAggToCop = hintinfo.PreferAggToCop
|
|
}
|
|
plan4Agg.SetChildren(np)
|
|
|
|
// Create a "max" or "min" aggregation.
|
|
funcName := ast.AggFuncMax
|
|
if useMin {
|
|
funcName = ast.AggFuncMin
|
|
}
|
|
funcMaxOrMin, err := aggregation.NewAggFuncDesc(planCtx.builder.ctx.GetExprCtx(), funcName, []expression.Expression{rexpr}, false)
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
|
|
// Create a column and append it to the schema of that aggregation.
|
|
colMaxOrMin := &expression.Column{
|
|
UniqueID: planCtx.builder.ctx.GetSessionVars().AllocPlanColumnID(),
|
|
RetType: funcMaxOrMin.RetTp,
|
|
}
|
|
colMaxOrMin.SetCoercibility(rexpr.Coercibility())
|
|
schema := expression.NewSchema(colMaxOrMin)
|
|
|
|
plan4Agg.SetOutputNames(append(plan4Agg.OutputNames(), types.EmptyName))
|
|
plan4Agg.SetSchema(schema)
|
|
plan4Agg.AggFuncs = []*aggregation.AggFuncDesc{funcMaxOrMin}
|
|
|
|
cond := expression.NewFunctionInternal(er.sctx, cmpFunc, types.NewFieldType(mysql.TypeTiny), lexpr, colMaxOrMin)
|
|
er.buildQuantifierPlan(planCtx, plan4Agg, cond, lexpr, rexpr, all, markNoDecorrelate)
|
|
}
|
|
|
|
// buildQuantifierPlan adds extra condition for any / all subquery.
|
|
func (er *expressionRewriter) buildQuantifierPlan(planCtx *exprRewriterPlanCtx, plan4Agg *logicalop.LogicalAggregation, cond, lexpr, rexpr expression.Expression, all, markNoDecorrelate bool) {
|
|
intest.AssertNotNil(planCtx)
|
|
innerIsNull := expression.NewFunctionInternal(er.sctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), rexpr)
|
|
outerIsNull := expression.NewFunctionInternal(er.sctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), lexpr)
|
|
exprCtx := planCtx.builder.ctx.GetExprCtx()
|
|
funcSum, err := aggregation.NewAggFuncDesc(exprCtx, ast.AggFuncSum, []expression.Expression{innerIsNull}, false)
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
sessVars := planCtx.builder.ctx.GetSessionVars()
|
|
colSum := &expression.Column{
|
|
UniqueID: sessVars.AllocPlanColumnID(),
|
|
RetType: funcSum.RetTp,
|
|
}
|
|
plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcSum)
|
|
plan4Agg.Schema().Append(colSum)
|
|
innerHasNull := expression.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum, expression.NewZero())
|
|
|
|
// Build `count(1)` aggregation to check if subquery is empty.
|
|
funcCount, err := aggregation.NewAggFuncDesc(exprCtx, ast.AggFuncCount, []expression.Expression{expression.NewOne()}, false)
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
colCount := &expression.Column{
|
|
UniqueID: sessVars.AllocPlanColumnID(),
|
|
RetType: funcCount.RetTp,
|
|
}
|
|
plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcCount)
|
|
plan4Agg.Schema().Append(colCount)
|
|
|
|
if all {
|
|
// All of the inner record set should not contain null value. So for t.id < all(select s.id from s), it
|
|
// should be rewrote to t.id < min(s.id) and if(sum(s.id is null) != 0, null, true).
|
|
innerNullChecker := expression.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.NewNull(), expression.NewOne())
|
|
cond = expression.ComposeCNFCondition(er.sctx, cond, innerNullChecker)
|
|
// If the subquery is empty, it should always return true.
|
|
emptyChecker := expression.NewFunctionInternal(er.sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colCount, expression.NewZero())
|
|
// If outer key is null, and subquery is not empty, it should always return null, even when it is `null = all (1, 2)`.
|
|
outerNullChecker := expression.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.NewNull(), expression.NewZero())
|
|
cond = expression.ComposeDNFCondition(er.sctx, cond, emptyChecker, outerNullChecker)
|
|
} else {
|
|
// For "any" expression, if the subquery has null and the cond returns false, the result should be NULL.
|
|
// Specifically, `t.id < any (select s.id from s)` would be rewrote to `t.id < max(s.id) or if(sum(s.id is null) != 0, null, false)`
|
|
innerNullChecker := expression.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.NewNull(), expression.NewZero())
|
|
cond = expression.ComposeDNFCondition(er.sctx, cond, innerNullChecker)
|
|
// If the subquery is empty, it should always return false.
|
|
emptyChecker := expression.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colCount, expression.NewZero())
|
|
// If outer key is null, and subquery is not empty, it should return null.
|
|
outerNullChecker := expression.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.NewNull(), expression.NewOne())
|
|
cond = expression.ComposeCNFCondition(er.sctx, cond, emptyChecker, outerNullChecker)
|
|
}
|
|
|
|
// TODO: Add a Projection if any argument of aggregate funcs or group by items are scalar functions.
|
|
// plan4Agg.buildProjectionIfNecessary()
|
|
if !er.asScalar {
|
|
// For Semi LogicalApply without aux column, the result is no matter false or null. So we can add it to join predicate.
|
|
planCtx.plan, er.err = planCtx.builder.buildSemiApply(planCtx.plan, plan4Agg, []expression.Expression{cond}, false, false, false, markNoDecorrelate)
|
|
return
|
|
}
|
|
// If we treat the result as a scalar value, we will add a projection with a extra column to output true, false or null.
|
|
outerSchemaLen := planCtx.plan.Schema().Len()
|
|
planCtx.plan = planCtx.builder.buildApplyWithJoinType(planCtx.plan, plan4Agg, base.InnerJoin, markNoDecorrelate)
|
|
joinSchema := planCtx.plan.Schema()
|
|
proj := logicalop.LogicalProjection{
|
|
Exprs: expression.Column2Exprs(joinSchema.Clone().Columns[:outerSchemaLen]),
|
|
}.Init(planCtx.builder.ctx, planCtx.builder.getSelectOffset())
|
|
proj.SetOutputNames(make([]*types.FieldName, outerSchemaLen, outerSchemaLen+1))
|
|
names := proj.OutputNames()
|
|
copy(names, planCtx.plan.OutputNames())
|
|
proj.SetOutputNames(names)
|
|
proj.SetSchema(expression.NewSchema(joinSchema.Clone().Columns[:outerSchemaLen]...))
|
|
proj.Exprs = append(proj.Exprs, cond)
|
|
proj.Schema().Append(&expression.Column{
|
|
UniqueID: sessVars.AllocPlanColumnID(),
|
|
RetType: cond.GetType(er.sctx.GetEvalCtx()),
|
|
})
|
|
proj.SetOutputNames(append(proj.OutputNames(), types.EmptyName))
|
|
proj.SetChildren(planCtx.plan)
|
|
planCtx.plan = proj
|
|
}
|
|
|
|
// handleNEAny handles the case of != any. For example, if the query is t.id != any (select s.id from s), it will be rewrote to
|
|
// t.id != s.id or count(distinct s.id) > 1 or [any checker]. If there are two different values in s.id ,
|
|
// there must exist a s.id that doesn't equal to t.id.
|
|
func (er *expressionRewriter) handleNEAny(planCtx *exprRewriterPlanCtx, lexpr, rexpr expression.Expression, np base.LogicalPlan, markNoDecorrelate bool) {
|
|
intest.AssertNotNil(planCtx)
|
|
sctx := planCtx.builder.ctx
|
|
exprCtx := sctx.GetExprCtx()
|
|
// If there is NULL in s.id column, s.id should be the value that isn't null in condition t.id != s.id.
|
|
// So use function max to filter NULL.
|
|
maxFunc, err := aggregation.NewAggFuncDesc(exprCtx, ast.AggFuncMax, []expression.Expression{rexpr}, false)
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
countFunc, err := aggregation.NewAggFuncDesc(exprCtx, ast.AggFuncCount, []expression.Expression{rexpr}, true)
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
plan4Agg := logicalop.LogicalAggregation{
|
|
AggFuncs: []*aggregation.AggFuncDesc{maxFunc, countFunc},
|
|
}.Init(sctx, planCtx.builder.getSelectOffset())
|
|
if hintinfo := planCtx.builder.TableHints(); hintinfo != nil {
|
|
plan4Agg.PreferAggType = hintinfo.PreferAggType
|
|
plan4Agg.PreferAggToCop = hintinfo.PreferAggToCop
|
|
}
|
|
plan4Agg.SetChildren(np)
|
|
maxResultCol := &expression.Column{
|
|
UniqueID: sctx.GetSessionVars().AllocPlanColumnID(),
|
|
RetType: maxFunc.RetTp,
|
|
}
|
|
maxResultCol.SetCoercibility(rexpr.Coercibility())
|
|
count := &expression.Column{
|
|
UniqueID: sctx.GetSessionVars().AllocPlanColumnID(),
|
|
RetType: countFunc.RetTp,
|
|
}
|
|
plan4Agg.SetOutputNames(append(plan4Agg.OutputNames(), types.EmptyName, types.EmptyName))
|
|
plan4Agg.SetSchema(expression.NewSchema(maxResultCol, count))
|
|
gtFunc := expression.NewFunctionInternal(er.sctx, ast.GT, types.NewFieldType(mysql.TypeTiny), count, expression.NewOne())
|
|
neCond := expression.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(mysql.TypeTiny), lexpr, maxResultCol)
|
|
cond := expression.ComposeDNFCondition(er.sctx, gtFunc, neCond)
|
|
er.buildQuantifierPlan(planCtx, plan4Agg, cond, lexpr, rexpr, false, markNoDecorrelate)
|
|
}
|
|
|
|
// handleEQAll handles the case of = all. For example, if the query is t.id = all (select s.id from s), it will be rewrote to
|
|
// t.id = (select s.id from s having count(distinct s.id) <= 1 and [all checker]).
|
|
func (er *expressionRewriter) handleEQAll(planCtx *exprRewriterPlanCtx, lexpr, rexpr expression.Expression, np base.LogicalPlan, markNoDecorrelate bool) {
|
|
intest.AssertNotNil(planCtx)
|
|
sctx := planCtx.builder.ctx
|
|
exprCtx := sctx.GetExprCtx()
|
|
// If there is NULL in s.id column, s.id should be the value that isn't null in condition t.id == s.id.
|
|
// So use function max to filter NULL.
|
|
maxFunc, err := aggregation.NewAggFuncDesc(exprCtx, ast.AggFuncMax, []expression.Expression{rexpr}, false)
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
countFunc, err := aggregation.NewAggFuncDesc(exprCtx, ast.AggFuncCount, []expression.Expression{rexpr}, true)
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
plan4Agg := logicalop.LogicalAggregation{
|
|
AggFuncs: []*aggregation.AggFuncDesc{maxFunc, countFunc},
|
|
}.Init(sctx, planCtx.builder.getSelectOffset())
|
|
if hintinfo := planCtx.builder.TableHints(); hintinfo != nil {
|
|
plan4Agg.PreferAggType = hintinfo.PreferAggType
|
|
plan4Agg.PreferAggToCop = hintinfo.PreferAggToCop
|
|
}
|
|
plan4Agg.SetChildren(np)
|
|
plan4Agg.SetOutputNames(append(plan4Agg.OutputNames(), types.EmptyName))
|
|
|
|
maxResultCol := &expression.Column{
|
|
UniqueID: sctx.GetSessionVars().AllocPlanColumnID(),
|
|
RetType: maxFunc.RetTp,
|
|
}
|
|
maxResultCol.SetCoercibility(rexpr.Coercibility())
|
|
plan4Agg.SetOutputNames(append(plan4Agg.OutputNames(), types.EmptyName))
|
|
count := &expression.Column{
|
|
UniqueID: sctx.GetSessionVars().AllocPlanColumnID(),
|
|
RetType: countFunc.RetTp,
|
|
}
|
|
plan4Agg.SetSchema(expression.NewSchema(maxResultCol, count))
|
|
leFunc := expression.NewFunctionInternal(er.sctx, ast.LE, types.NewFieldType(mysql.TypeTiny), count, expression.NewOne())
|
|
eqCond := expression.NewFunctionInternal(er.sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lexpr, maxResultCol)
|
|
cond := expression.ComposeCNFCondition(er.sctx, leFunc, eqCond)
|
|
er.buildQuantifierPlan(planCtx, plan4Agg, cond, lexpr, rexpr, true, markNoDecorrelate)
|
|
}
|
|
|
|
func (er *expressionRewriter) handleExistSubquery(ctx context.Context, planCtx *exprRewriterPlanCtx, v *ast.ExistsSubqueryExpr) (ast.Node, bool) {
|
|
intest.AssertNotNil(planCtx)
|
|
b := planCtx.builder
|
|
ci := b.prepareCTECheckForSubQuery()
|
|
defer resetCTECheckForSubQuery(ci)
|
|
subq, ok := v.Sel.(*ast.SubqueryExpr)
|
|
if !ok {
|
|
er.err = errors.Errorf("Unknown exists type %T", v.Sel)
|
|
return v, true
|
|
}
|
|
np, hintFlags, err := er.buildSubquery(ctx, planCtx, subq, handlingExistsSubquery)
|
|
if err != nil {
|
|
er.err = err
|
|
return v, true
|
|
}
|
|
// Add LIMIT 1 when noDecorrelate is true for EXISTS subqueries to enable early exit
|
|
corCols := coreusage.ExtractCorColumnsBySchema4LogicalPlan(np, planCtx.plan.Schema())
|
|
noDecorrelate := isNoDecorrelate(planCtx, corCols, hintFlags, handlingExistsSubquery)
|
|
if noDecorrelate {
|
|
// Only add LIMIT 1 if the query doesn't already contain a LIMIT clause
|
|
if !hasLimit(np) {
|
|
limitClause := &ast.Limit{
|
|
Count: ast.NewValueExpr(1, "", ""),
|
|
}
|
|
var err error
|
|
np, err = planCtx.builder.buildLimit(np, limitClause)
|
|
if err != nil {
|
|
er.err = err
|
|
return v, true
|
|
}
|
|
}
|
|
}
|
|
np = er.popExistsSubPlan(planCtx, np)
|
|
semiJoinRewrite := hintFlags&hint.HintFlagSemiJoinRewrite > 0
|
|
if semiJoinRewrite && noDecorrelate {
|
|
b.ctx.GetSessionVars().StmtCtx.SetHintWarning(
|
|
"NO_DECORRELATE() and SEMI_JOIN_REWRITE() are in conflict. Both will be ineffective.")
|
|
noDecorrelate = false
|
|
semiJoinRewrite = false
|
|
}
|
|
|
|
if b.disableSubQueryPreprocessing || len(coreusage.ExtractCorrelatedCols4LogicalPlan(np)) > 0 || hasCTEConsumerInSubPlan(np) {
|
|
planCtx.plan, er.err = b.buildSemiApply(planCtx.plan, np, nil, er.asScalar, v.Not, semiJoinRewrite, noDecorrelate)
|
|
if er.err != nil || !er.asScalar {
|
|
return v, true
|
|
}
|
|
er.ctxStackAppend(planCtx.plan.Schema().Columns[planCtx.plan.Schema().Len()-1], planCtx.plan.OutputNames()[planCtx.plan.Schema().Len()-1])
|
|
} else {
|
|
// We don't want nth_plan hint to affect separately executed subqueries here, so disable nth_plan temporarily.
|
|
nthPlanBackup := b.ctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan
|
|
b.ctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan = -1
|
|
physicalPlan, _, err := DoOptimize(ctx, planCtx.builder.ctx, b.optFlag, np)
|
|
b.ctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan = nthPlanBackup
|
|
if err != nil {
|
|
er.err = err
|
|
return v, true
|
|
}
|
|
if b.ctx.GetSessionVars().StmtCtx.InExplainStmt && !b.ctx.GetSessionVars().StmtCtx.InExplainAnalyzeStmt && b.ctx.GetSessionVars().ExplainNonEvaledSubQuery {
|
|
newColID := b.ctx.GetSessionVars().AllocPlanColumnID()
|
|
subqueryCtx := ScalarSubqueryEvalCtx{
|
|
scalarSubQuery: physicalPlan,
|
|
ctx: ctx,
|
|
is: b.is,
|
|
outputColIDs: []int64{newColID},
|
|
}.Init(b.ctx, np.QueryBlockOffset())
|
|
scalarSubQ := &ScalarSubQueryExpr{
|
|
scalarSubqueryColID: newColID,
|
|
evalCtx: subqueryCtx,
|
|
}
|
|
scalarSubQ.RetType = np.Schema().Columns[0].GetType(er.sctx.GetEvalCtx())
|
|
scalarSubQ.SetCoercibility(np.Schema().Columns[0].Coercibility())
|
|
b.ctx.GetSessionVars().RegisterScalarSubQ(subqueryCtx)
|
|
if v.Not {
|
|
notWrapped, err := expression.NewFunction(b.ctx.GetExprCtx(), ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), scalarSubQ)
|
|
if err != nil {
|
|
er.err = err
|
|
return v, true
|
|
}
|
|
er.ctxStackAppend(notWrapped, types.EmptyName)
|
|
return v, true
|
|
}
|
|
er.ctxStackAppend(scalarSubQ, types.EmptyName)
|
|
return v, true
|
|
}
|
|
// register the subquery plan but continue with normal execution
|
|
subqueryCtx := ScalarSubqueryEvalCtx{
|
|
scalarSubQuery: physicalPlan,
|
|
ctx: ctx,
|
|
is: b.is,
|
|
}.Init(planCtx.builder.ctx, np.QueryBlockOffset())
|
|
newColIDs := make([]int64, 0, np.Schema().Len())
|
|
for range np.Schema().Columns {
|
|
newColID := planCtx.builder.ctx.GetSessionVars().AllocPlanColumnID()
|
|
newColIDs = append(newColIDs, newColID)
|
|
}
|
|
subqueryCtx.outputColIDs = newColIDs
|
|
|
|
planCtx.builder.ctx.GetSessionVars().RegisterScalarSubQ(subqueryCtx)
|
|
|
|
row, err := EvalSubqueryFirstRow(ctx, physicalPlan, b.is, b.ctx)
|
|
if err != nil {
|
|
er.err = err
|
|
return v, true
|
|
}
|
|
if (row != nil && !v.Not) || (row == nil && v.Not) {
|
|
er.ctxStackAppend(expression.NewSignedOne(), types.EmptyName)
|
|
} else {
|
|
er.ctxStackAppend(expression.NewSignedZero(), types.EmptyName)
|
|
}
|
|
}
|
|
return v, true
|
|
}
|
|
|
|
// popExistsSubPlan will remove the useless plan in exist's child.
|
|
// See comments inside the method for more details.
|
|
func (*expressionRewriter) popExistsSubPlan(planCtx *exprRewriterPlanCtx, p base.LogicalPlan) base.LogicalPlan {
|
|
intest.AssertNotNil(planCtx)
|
|
out:
|
|
for {
|
|
switch plan := p.(type) {
|
|
// This can be removed when in exists clause,
|
|
// e.g. exists(select count(*) from t order by a) is equal to exists t.
|
|
case *logicalop.LogicalProjection, *logicalop.LogicalSort:
|
|
p = p.Children()[0]
|
|
case *logicalop.LogicalAggregation:
|
|
if len(plan.GroupByItems) == 0 {
|
|
p = logicalop.LogicalTableDual{RowCount: 1}.Init(planCtx.builder.ctx, planCtx.builder.getSelectOffset())
|
|
break out
|
|
}
|
|
p = p.Children()[0]
|
|
default:
|
|
break out
|
|
}
|
|
}
|
|
return p
|
|
}
|
|
|
|
func (er *expressionRewriter) handleInSubquery(ctx context.Context, planCtx *exprRewriterPlanCtx, v *ast.PatternInExpr) (ast.Node, bool) {
|
|
intest.AssertNotNil(planCtx)
|
|
ci := planCtx.builder.prepareCTECheckForSubQuery()
|
|
defer resetCTECheckForSubQuery(ci)
|
|
asScalar := er.asScalar
|
|
er.asScalar = true
|
|
v.Expr.Accept(er)
|
|
if er.err != nil {
|
|
return v, true
|
|
}
|
|
lexpr := er.ctxStack[len(er.ctxStack)-1]
|
|
subq, ok := v.Sel.(*ast.SubqueryExpr)
|
|
if !ok {
|
|
er.err = errors.Errorf("Unknown compare type %T", v.Sel)
|
|
return v, true
|
|
}
|
|
np, hintFlags, err := er.buildSubquery(ctx, planCtx, subq, handlingInSubquery)
|
|
if err != nil {
|
|
er.err = err
|
|
return v, true
|
|
}
|
|
lLen := expression.GetRowLen(lexpr)
|
|
if lLen != np.Schema().Len() {
|
|
er.err = expression.ErrOperandColumns.GenWithStackByArgs(lLen)
|
|
return v, true
|
|
}
|
|
var rexpr expression.Expression
|
|
if np.Schema().Len() == 1 {
|
|
rexpr = np.Schema().Columns[0]
|
|
rCol := rexpr.(*expression.Column)
|
|
// For AntiSemiJoin/LeftOuterSemiJoin/AntiLeftOuterSemiJoin, we cannot treat `in` expression as
|
|
// normal column equal condition, so we specially mark the inner operand here.
|
|
if v.Not || asScalar {
|
|
// If both input columns of `in` expression are not null, we can treat the expression
|
|
// as normal column equal condition instead. Otherwise, mark the left and right side.
|
|
// eg: for some optimization, the column substitute in right side in projection elimination
|
|
// will cause case like <lcol EQ rcol(inOperand)> as <lcol EQ constant> which is not
|
|
// a valid null-aware EQ. (null in lcol still need to be null-aware)
|
|
if !expression.ExprNotNull(er.sctx.GetEvalCtx(), lexpr) || !expression.ExprNotNull(er.sctx.GetEvalCtx(), rCol) {
|
|
rColCopy := *rCol
|
|
rColCopy.InOperand = true
|
|
rexpr = &rColCopy
|
|
lexpr = expression.SetExprColumnInOperand(lexpr)
|
|
}
|
|
}
|
|
} else {
|
|
args := make([]expression.Expression, 0, np.Schema().Len())
|
|
for i, col := range np.Schema().Columns {
|
|
if v.Not || asScalar {
|
|
larg := expression.GetFuncArg(lexpr, i)
|
|
// If both input columns of `in` expression are not null, we can treat the expression
|
|
// as normal column equal condition instead. Otherwise, mark the left and right side.
|
|
if !expression.ExprNotNull(er.sctx.GetEvalCtx(), larg) || !expression.ExprNotNull(er.sctx.GetEvalCtx(), col) {
|
|
rarg := *col
|
|
rarg.InOperand = true
|
|
col = &rarg
|
|
if larg != nil {
|
|
lexpr.(*expression.ScalarFunction).GetArgs()[i] = expression.SetExprColumnInOperand(larg)
|
|
}
|
|
}
|
|
}
|
|
args = append(args, col)
|
|
}
|
|
rexpr, er.err = er.newFunction(ast.RowFunc, args[0].GetType(er.sctx.GetEvalCtx()), args...)
|
|
if er.err != nil {
|
|
return v, true
|
|
}
|
|
}
|
|
checkCondition, err := er.constructBinaryOpFunction(lexpr, rexpr, ast.EQ)
|
|
if err != nil {
|
|
er.err = err
|
|
return v, true
|
|
}
|
|
|
|
// If the leftKey and the rightKey have different collations, don't convert the sub-query to an inner-join
|
|
// since when converting we will add a distinct-agg upon the right child and this distinct-agg doesn't have the right collation.
|
|
// To keep it simple, we forbid this converting if they have different collations.
|
|
// tested by TestCollateSubQuery.
|
|
lt, rt := lexpr.GetType(er.sctx.GetEvalCtx()), rexpr.GetType(er.sctx.GetEvalCtx())
|
|
collFlag := collate.CompatibleCollate(lt.GetCollate(), rt.GetCollate())
|
|
corCols := coreusage.ExtractCorColumnsBySchema4LogicalPlan(np, planCtx.plan.Schema())
|
|
noDecorrelate := isNoDecorrelate(planCtx, corCols, hintFlags, handlingInSubquery)
|
|
|
|
// If it's not the form of `not in (SUBQUERY)`,
|
|
// and has no correlated column from the current level plan(if the correlated column is from upper level,
|
|
// we can treat it as constant, because the upper LogicalApply cannot be eliminated since current node is a join node),
|
|
// and don't need to append a scalar value, we can rewrite it to inner join.
|
|
if planCtx.builder.ctx.GetSessionVars().GetAllowInSubqToJoinAndAgg() && !v.Not && !asScalar && len(corCols) == 0 && collFlag {
|
|
// We need to try to eliminate the agg and the projection produced by this operation.
|
|
planCtx.builder.optFlag |= rule.FlagEliminateAgg
|
|
planCtx.builder.optFlag |= rule.FlagEliminateProjection
|
|
planCtx.builder.optFlag |= rule.FlagJoinReOrder
|
|
planCtx.builder.optFlag |= rule.FlagEmptySelectionEliminator
|
|
// Build distinct for the inner query.
|
|
agg, err := planCtx.builder.buildDistinct(np, np.Schema().Len())
|
|
if err != nil {
|
|
er.err = err
|
|
return v, true
|
|
}
|
|
// Build inner join above the aggregation.
|
|
join := logicalop.LogicalJoin{JoinType: base.InnerJoin}.Init(planCtx.builder.ctx, planCtx.builder.getSelectOffset())
|
|
join.SetChildren(planCtx.plan, agg)
|
|
join.SetSchema(expression.MergeSchema(planCtx.plan.Schema(), agg.Schema()))
|
|
join.SetOutputNames(make([]*types.FieldName, planCtx.plan.Schema().Len()+agg.Schema().Len()))
|
|
copy(join.OutputNames(), planCtx.plan.OutputNames())
|
|
copy(join.OutputNames()[planCtx.plan.Schema().Len():], agg.OutputNames())
|
|
join.AttachOnConds(expression.SplitCNFItems(checkCondition))
|
|
// set FullSchema and FullNames for this join
|
|
if left, ok := planCtx.plan.(*logicalop.LogicalJoin); ok && left.FullSchema != nil {
|
|
join.FullSchema = left.FullSchema
|
|
join.FullNames = left.FullNames
|
|
}
|
|
// Set join hint for this join.
|
|
if planCtx.builder.TableHints() != nil {
|
|
join.SetPreferredJoinTypeAndOrder(planCtx.builder.TableHints())
|
|
}
|
|
planCtx.plan = join
|
|
} else {
|
|
semiRewrite := hintFlags&hint.HintFlagSemiJoinRewrite > 0
|
|
planCtx.plan, er.err = planCtx.builder.buildSemiApply(planCtx.plan, np, expression.SplitCNFItems(checkCondition), asScalar, v.Not, semiRewrite, noDecorrelate)
|
|
if er.err != nil {
|
|
return v, true
|
|
}
|
|
}
|
|
|
|
er.ctxStackPop(1)
|
|
if asScalar {
|
|
col := planCtx.plan.Schema().Columns[planCtx.plan.Schema().Len()-1]
|
|
er.ctxStackAppend(col, planCtx.plan.OutputNames()[planCtx.plan.Schema().Len()-1])
|
|
}
|
|
return v, true
|
|
}
|
|
|
|
func isNoDecorrelate(planCtx *exprRewriterPlanCtx, corCols []*expression.CorrelatedColumn, hintFlags uint64, sCtx subQueryCtx) bool {
|
|
noDecorrelate := hintFlags&hint.HintFlagNoDecorrelate > 0
|
|
|
|
if len(corCols) == 0 {
|
|
if noDecorrelate {
|
|
planCtx.builder.ctx.GetSessionVars().StmtCtx.SetHintWarning(
|
|
"NO_DECORRELATE() is inapplicable because there are no correlated columns.")
|
|
noDecorrelate = false
|
|
}
|
|
} else {
|
|
semiJoinRewrite := hintFlags&hint.HintFlagSemiJoinRewrite > 0
|
|
// We can't override noDecorrelate via the variable for EXISTS subqueries with semi join rewrite
|
|
// as this will cause a conflict that will result in both being disabled in later code
|
|
// SemiJoinRewrite does not check the variable TiDBOptEnableSemiJoinRewrite.
|
|
// If that variable is enabled - we can still choose NOT to decorrelate here.
|
|
if !(semiJoinRewrite && sCtx == handlingExistsSubquery) {
|
|
// Only support scalar and exists subqueries
|
|
validSubqType := sCtx == handlingScalarSubquery || sCtx == handlingExistsSubquery
|
|
if validSubqType && planCtx.curClause == fieldList { // subquery is in the select list
|
|
planCtx.builder.ctx.GetSessionVars().RecordRelevantOptVar(vardef.TiDBOptEnableNoDecorrelateInSelect)
|
|
// If it isn't already enabled via hint, and variable is set, then enable it
|
|
if !noDecorrelate && planCtx.builder.ctx.GetSessionVars().EnableNoDecorrelateInSelect {
|
|
noDecorrelate = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return noDecorrelate
|
|
}
|
|
|
|
func (er *expressionRewriter) handleScalarSubquery(ctx context.Context, planCtx *exprRewriterPlanCtx, v *ast.SubqueryExpr) (ast.Node, bool) {
|
|
intest.AssertNotNil(planCtx)
|
|
ci := planCtx.builder.prepareCTECheckForSubQuery()
|
|
defer resetCTECheckForSubQuery(ci)
|
|
np, hintFlags, err := er.buildSubquery(ctx, planCtx, v, handlingScalarSubquery)
|
|
if err != nil {
|
|
er.err = err
|
|
return v, true
|
|
}
|
|
np = planCtx.builder.buildMaxOneRow(np)
|
|
correlatedColumn := coreusage.ExtractCorColumnsBySchema4LogicalPlan(np, planCtx.plan.Schema())
|
|
noDecorrelate := isNoDecorrelate(planCtx, correlatedColumn, hintFlags, handlingScalarSubquery)
|
|
|
|
if planCtx.builder.disableSubQueryPreprocessing || len(coreusage.ExtractCorrelatedCols4LogicalPlan(np)) > 0 || hasCTEConsumerInSubPlan(np) {
|
|
planCtx.plan = planCtx.builder.buildApplyWithJoinType(planCtx.plan, np, base.LeftOuterJoin, noDecorrelate)
|
|
if np.Schema().Len() > 1 {
|
|
newCols := make([]expression.Expression, 0, np.Schema().Len())
|
|
for _, col := range np.Schema().Columns {
|
|
newCols = append(newCols, col)
|
|
}
|
|
expr, err1 := er.newFunction(ast.RowFunc, newCols[0].GetType(er.sctx.GetEvalCtx()), newCols...)
|
|
if err1 != nil {
|
|
er.err = err1
|
|
return v, true
|
|
}
|
|
er.ctxStackAppend(expr, types.EmptyName)
|
|
} else {
|
|
er.ctxStackAppend(planCtx.plan.Schema().Columns[planCtx.plan.Schema().Len()-1], planCtx.plan.OutputNames()[planCtx.plan.Schema().Len()-1])
|
|
}
|
|
return v, true
|
|
}
|
|
// We don't want nth_plan hint to affect separately executed subqueries here, so disable nth_plan temporarily.
|
|
nthPlanBackup := planCtx.builder.ctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan
|
|
planCtx.builder.ctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan = -1
|
|
physicalPlan, _, err := DoOptimize(ctx, planCtx.builder.ctx, planCtx.builder.optFlag, np)
|
|
planCtx.builder.ctx.GetSessionVars().StmtCtx.StmtHints.ForceNthPlan = nthPlanBackup
|
|
if err != nil {
|
|
er.err = err
|
|
return v, true
|
|
}
|
|
if planCtx.builder.ctx.GetSessionVars().StmtCtx.InExplainStmt && !planCtx.builder.ctx.GetSessionVars().StmtCtx.InExplainAnalyzeStmt && planCtx.builder.ctx.GetSessionVars().ExplainNonEvaledSubQuery {
|
|
subqueryCtx := ScalarSubqueryEvalCtx{
|
|
scalarSubQuery: physicalPlan,
|
|
ctx: ctx,
|
|
is: planCtx.builder.is,
|
|
}.Init(planCtx.builder.ctx, np.QueryBlockOffset())
|
|
newColIDs := make([]int64, 0, np.Schema().Len())
|
|
newScalarSubQueryExprs := make([]expression.Expression, 0, np.Schema().Len())
|
|
for _, col := range np.Schema().Columns {
|
|
newColID := planCtx.builder.ctx.GetSessionVars().AllocPlanColumnID()
|
|
scalarSubQ := &ScalarSubQueryExpr{
|
|
scalarSubqueryColID: newColID,
|
|
evalCtx: subqueryCtx,
|
|
}
|
|
scalarSubQ.RetType = col.RetType
|
|
scalarSubQ.SetCoercibility(col.Coercibility())
|
|
newColIDs = append(newColIDs, newColID)
|
|
newScalarSubQueryExprs = append(newScalarSubQueryExprs, scalarSubQ)
|
|
}
|
|
subqueryCtx.outputColIDs = newColIDs
|
|
|
|
planCtx.builder.ctx.GetSessionVars().RegisterScalarSubQ(subqueryCtx)
|
|
if len(newScalarSubQueryExprs) == 1 {
|
|
er.ctxStackAppend(newScalarSubQueryExprs[0], types.EmptyName)
|
|
} else {
|
|
rowFunc, err := er.newFunction(ast.RowFunc, newScalarSubQueryExprs[0].GetType(er.sctx.GetEvalCtx()), newScalarSubQueryExprs...)
|
|
if err != nil {
|
|
er.err = err
|
|
return v, true
|
|
}
|
|
er.ctxStack = append(er.ctxStack, rowFunc)
|
|
}
|
|
return v, true
|
|
}
|
|
|
|
// register the subquery plan but continue with normal execution
|
|
subqueryCtx := ScalarSubqueryEvalCtx{
|
|
scalarSubQuery: physicalPlan,
|
|
ctx: ctx,
|
|
is: planCtx.builder.is,
|
|
}.Init(planCtx.builder.ctx, np.QueryBlockOffset())
|
|
newColIDs := make([]int64, 0, np.Schema().Len())
|
|
for range np.Schema().Columns {
|
|
newColID := planCtx.builder.ctx.GetSessionVars().AllocPlanColumnID()
|
|
newColIDs = append(newColIDs, newColID)
|
|
}
|
|
subqueryCtx.outputColIDs = newColIDs
|
|
|
|
planCtx.builder.ctx.GetSessionVars().RegisterScalarSubQ(subqueryCtx)
|
|
row, err := EvalSubqueryFirstRow(ctx, physicalPlan, planCtx.builder.is, planCtx.builder.ctx)
|
|
if err != nil {
|
|
er.err = err
|
|
return v, true
|
|
}
|
|
newCols := make([]expression.Expression, 0, np.Schema().Len())
|
|
for i, data := range row {
|
|
constant := &expression.Constant{
|
|
Value: data,
|
|
RetType: np.Schema().Columns[i].GetType(er.sctx.GetEvalCtx()),
|
|
SubqueryRefID: newColIDs[i],
|
|
}
|
|
constant.SetCoercibility(np.Schema().Columns[i].Coercibility())
|
|
newCols = append(newCols, constant)
|
|
}
|
|
|
|
if np.Schema().Len() > 1 {
|
|
expr, err := er.newFunction(ast.RowFunc, newCols[0].GetType(er.sctx.GetEvalCtx()), newCols...)
|
|
if err != nil {
|
|
er.err = err
|
|
return v, true
|
|
}
|
|
er.ctxStackAppend(expr, types.EmptyName)
|
|
} else {
|
|
er.ctxStackAppend(newCols[0], types.EmptyName)
|
|
}
|
|
return v, true
|
|
}
|
|
|
|
func hasCTEConsumerInSubPlan(p base.LogicalPlan) bool {
|
|
if _, ok := p.(*logicalop.LogicalCTE); ok {
|
|
return true
|
|
}
|
|
return slices.ContainsFunc(p.Children(), hasCTEConsumerInSubPlan)
|
|
}
|
|
|
|
func initConstantRepertoire(ctx expression.EvalContext, c *expression.Constant) {
|
|
c.SetRepertoire(expression.ASCII)
|
|
if c.GetType(ctx).EvalType() == types.ETString {
|
|
for _, b := range c.Value.GetBytes() {
|
|
// if any character in constant is not ascii, set the repertoire to UNICODE.
|
|
if b >= 0x80 {
|
|
c.SetRepertoire(expression.UNICODE)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (er *expressionRewriter) adjustUTF8MB4Collation(tp *types.FieldType) {
|
|
if tp.GetFlag()&mysql.UnderScoreCharsetFlag > 0 && charset.CharsetUTF8MB4 == tp.GetCharset() {
|
|
tp.SetCollate(er.sctx.GetDefaultCollationForUTF8MB4())
|
|
}
|
|
}
|
|
|
|
// Leave implements Visitor interface.
|
|
func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok bool) {
|
|
if er.err != nil {
|
|
return retNode, false
|
|
}
|
|
var inNode = originInNode
|
|
if er.preprocess != nil {
|
|
inNode = er.preprocess(inNode)
|
|
}
|
|
|
|
withPlanCtx := func(fn func(*exprRewriterPlanCtx), detail string) {
|
|
planCtx, err := er.requirePlanCtx(inNode, detail)
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
fn(planCtx)
|
|
}
|
|
|
|
switch v := inNode.(type) {
|
|
case *ast.AggregateFuncExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.WhenClause,
|
|
*ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr, *ast.ValuesExpr, *ast.WindowFuncExpr, *ast.TableNameExpr:
|
|
case *driver.ValueExpr:
|
|
// set right not null flag for constant value
|
|
retType := v.Type.Clone()
|
|
switch v.Datum.Kind() {
|
|
case types.KindNull:
|
|
retType.DelFlag(mysql.NotNullFlag)
|
|
default:
|
|
retType.AddFlag(mysql.NotNullFlag)
|
|
}
|
|
v.Datum.SetValue(v.Datum.GetValue(), retType)
|
|
value := &expression.Constant{Value: v.Datum, RetType: retType}
|
|
initConstantRepertoire(er.sctx.GetEvalCtx(), value)
|
|
er.adjustUTF8MB4Collation(retType)
|
|
if er.err != nil {
|
|
return retNode, false
|
|
}
|
|
er.ctxStackAppend(value, types.EmptyName)
|
|
case *driver.ParamMarkerExpr:
|
|
er.toParamMarker(v)
|
|
case *ast.VariableExpr:
|
|
if v.IsSystem {
|
|
withPlanCtx(func(planCtx *exprRewriterPlanCtx) {
|
|
er.rewriteSystemVariable(planCtx, v)
|
|
}, "accessing system variable requires plan context")
|
|
} else {
|
|
er.rewriteUserVariable(v)
|
|
}
|
|
case *ast.FuncCallExpr:
|
|
switch v.FnName.L {
|
|
case ast.Grouping:
|
|
withPlanCtx(func(planCtx *exprRewriterPlanCtx) {
|
|
er.funcCallToExpressionWithPlanCtx(planCtx, v)
|
|
}, "grouping function requires plan context")
|
|
default:
|
|
if _, ok := expression.TryFoldFunctions[v.FnName.L]; ok {
|
|
er.tryFoldCounter--
|
|
}
|
|
er.funcCallToExpression(v)
|
|
if _, ok := expression.DisableFoldFunctions[v.FnName.L]; ok {
|
|
er.disableFoldCounter--
|
|
}
|
|
}
|
|
case *ast.TableName:
|
|
er.toTable(v)
|
|
case *ast.ColumnName:
|
|
er.toColumn(v)
|
|
case *ast.UnaryOperationExpr:
|
|
er.unaryOpToExpression(v)
|
|
case *ast.BinaryOperationExpr:
|
|
if v.Op == opcode.LogicAnd || v.Op == opcode.LogicOr {
|
|
er.tryFoldCounter--
|
|
}
|
|
er.binaryOpToExpression(v)
|
|
case *ast.BetweenExpr:
|
|
er.betweenToExpression(v)
|
|
case *ast.CaseExpr:
|
|
if _, ok := expression.TryFoldFunctions["case"]; ok {
|
|
er.tryFoldCounter--
|
|
}
|
|
er.caseToExpression(v)
|
|
if _, ok := expression.DisableFoldFunctions["case"]; ok {
|
|
er.disableFoldCounter--
|
|
}
|
|
case *ast.FuncCastExpr:
|
|
if v.Tp.IsArray() && !er.allowBuildCastArray {
|
|
er.err = expression.ErrNotSupportedYet.GenWithStackByArgs("Use of CAST( .. AS .. ARRAY) outside of functional index in CREATE(non-SELECT)/ALTER TABLE or in general expressions")
|
|
return retNode, false
|
|
}
|
|
arg := er.ctxStack[len(er.ctxStack)-1]
|
|
er.err = expression.CheckArgsNotMultiColumnRow(arg)
|
|
if er.err != nil {
|
|
return retNode, false
|
|
}
|
|
|
|
// check the decimal precision of "CAST(AS TIME)".
|
|
er.err = er.checkTimePrecision(v.Tp)
|
|
if er.err != nil {
|
|
return retNode, false
|
|
}
|
|
|
|
castFunction, err := expression.BuildCastFunctionWithCheck(er.sctx, arg, v.Tp, false, v.ExplicitCharSet)
|
|
if err != nil {
|
|
er.err = err
|
|
return retNode, false
|
|
}
|
|
if v.Tp.EvalType() == types.ETString {
|
|
castFunction.SetCoercibility(expression.CoercibilityImplicit)
|
|
if v.Tp.GetCharset() == charset.CharsetASCII {
|
|
castFunction.SetRepertoire(expression.ASCII)
|
|
} else {
|
|
castFunction.SetRepertoire(expression.UNICODE)
|
|
}
|
|
} else {
|
|
castFunction.SetCoercibility(expression.CoercibilityNumeric)
|
|
castFunction.SetRepertoire(expression.ASCII)
|
|
}
|
|
|
|
er.ctxStack[len(er.ctxStack)-1] = castFunction
|
|
er.ctxNameStk[len(er.ctxNameStk)-1] = types.EmptyName
|
|
case *ast.JSONSumCrc32Expr:
|
|
arg := er.ctxStack[len(er.ctxStack)-1]
|
|
jsonSumFunction, err := expression.BuildJSONSumCrc32FunctionWithCheck(er.sctx, arg, v.Tp)
|
|
if err != nil {
|
|
er.err = err
|
|
return retNode, false
|
|
}
|
|
|
|
jsonSumFunction.SetCoercibility(expression.CoercibilityNumeric)
|
|
jsonSumFunction.SetRepertoire(expression.ASCII)
|
|
er.ctxStack[len(er.ctxStack)-1] = jsonSumFunction
|
|
er.ctxNameStk[len(er.ctxNameStk)-1] = types.EmptyName
|
|
case *ast.PatternLikeOrIlikeExpr:
|
|
er.patternLikeOrIlikeToExpression(v)
|
|
case *ast.PatternRegexpExpr:
|
|
er.regexpToScalarFunc(v)
|
|
case *ast.RowExpr:
|
|
er.rowToScalarFunc(v)
|
|
case *ast.PatternInExpr:
|
|
if v.Sel == nil {
|
|
er.inToExpression(len(v.List), v.Not, &v.Type)
|
|
}
|
|
case *ast.PositionExpr:
|
|
withPlanCtx(func(planCtx *exprRewriterPlanCtx) {
|
|
er.positionToScalarFunc(planCtx, v)
|
|
}, "")
|
|
case *ast.IsNullExpr:
|
|
er.isNullToExpression(v)
|
|
case *ast.IsTruthExpr:
|
|
er.isTrueToScalarFunc(v)
|
|
case *ast.DefaultExpr:
|
|
if planCtx := er.planCtx; planCtx != nil {
|
|
er.evalDefaultExprWithPlanCtx(planCtx, v)
|
|
} else if er.sourceTable != nil {
|
|
er.evalDefaultExprForTable(v, er.sourceTable)
|
|
} else {
|
|
er.err = errors.Errorf("Unsupported expr %T when source table not provided", v)
|
|
}
|
|
// TODO: Perhaps we don't need to transcode these back to generic integers/strings
|
|
case *ast.TrimDirectionExpr:
|
|
er.ctxStackAppend(&expression.Constant{
|
|
Value: types.NewIntDatum(int64(v.Direction)),
|
|
RetType: types.NewFieldType(mysql.TypeTiny),
|
|
}, types.EmptyName)
|
|
case *ast.TimeUnitExpr:
|
|
er.ctxStackAppend(&expression.Constant{
|
|
Value: types.NewStringDatum(v.Unit.String()),
|
|
RetType: types.NewFieldType(mysql.TypeVarchar),
|
|
}, types.EmptyName)
|
|
case *ast.GetFormatSelectorExpr:
|
|
er.ctxStackAppend(&expression.Constant{
|
|
Value: types.NewStringDatum(v.Selector.String()),
|
|
RetType: types.NewFieldType(mysql.TypeVarchar),
|
|
}, types.EmptyName)
|
|
case *ast.SetCollationExpr:
|
|
arg := er.ctxStack[len(er.ctxStack)-1]
|
|
if collate.NewCollationEnabled() {
|
|
var collInfo *charset.Collation
|
|
// TODO(bb7133): use charset.ValidCharsetAndCollation when its bug is fixed.
|
|
if collInfo, er.err = collate.GetCollationByName(v.Collate); er.err != nil {
|
|
break
|
|
}
|
|
chs := arg.GetType(er.sctx.GetEvalCtx()).GetCharset()
|
|
// if the field is json, the charset is always utf8mb4.
|
|
if arg.GetType(er.sctx.GetEvalCtx()).GetType() == mysql.TypeJSON {
|
|
chs = mysql.UTF8MB4Charset
|
|
}
|
|
if chs != "" && collInfo.CharsetName != chs {
|
|
er.err = charset.ErrCollationCharsetMismatch.GenWithStackByArgs(collInfo.Name, chs)
|
|
break
|
|
}
|
|
}
|
|
// SetCollationExpr sets the collation explicitly, even when the evaluation type of the expression is non-string.
|
|
if _, ok := arg.(*expression.Column); ok || arg.GetType(er.sctx.GetEvalCtx()).GetType() == mysql.TypeJSON {
|
|
if arg.GetType(er.sctx.GetEvalCtx()).GetType() == mysql.TypeEnum || arg.GetType(er.sctx.GetEvalCtx()).GetType() == mysql.TypeSet {
|
|
er.err = plannererrors.ErrNotSupportedYet.GenWithStackByArgs("use collate clause for enum or set")
|
|
break
|
|
}
|
|
// Wrap a cast here to avoid changing the original FieldType of the column expression.
|
|
exprType := arg.GetType(er.sctx.GetEvalCtx()).Clone()
|
|
// if arg type is json, we should cast it to longtext if there is collate clause.
|
|
if arg.GetType(er.sctx.GetEvalCtx()).GetType() == mysql.TypeJSON {
|
|
exprType = types.NewFieldType(mysql.TypeLongBlob)
|
|
exprType.SetCharset(mysql.UTF8MB4Charset)
|
|
}
|
|
exprType.SetCollate(v.Collate)
|
|
casted := expression.BuildCastFunction(er.sctx, arg, exprType)
|
|
arg = casted
|
|
er.ctxStackPop(1)
|
|
er.ctxStackAppend(casted, types.EmptyName)
|
|
} else {
|
|
// For constant and scalar function, we can set its collate directly.
|
|
arg.GetType(er.sctx.GetEvalCtx()).SetCollate(v.Collate)
|
|
}
|
|
er.ctxStack[len(er.ctxStack)-1].SetCoercibility(expression.CoercibilityExplicit)
|
|
er.ctxStack[len(er.ctxStack)-1].SetCharsetAndCollation(arg.GetType(er.sctx.GetEvalCtx()).GetCharset(), arg.GetType(er.sctx.GetEvalCtx()).GetCollate())
|
|
default:
|
|
er.err = errors.Errorf("UnknownType: %T", v)
|
|
return retNode, false
|
|
}
|
|
|
|
if er.err != nil {
|
|
return retNode, false
|
|
}
|
|
return originInNode, true
|
|
}
|
|
|
|
// newFunctionWithInit chooses which expression.NewFunctionImpl() will be used.
|
|
func (er *expressionRewriter) newFunctionWithInit(funcName string, retType *types.FieldType, init expression.ScalarFunctionCallBack, args ...expression.Expression) (ret expression.Expression, err error) {
|
|
if init != nil {
|
|
ret, err = expression.NewFunctionWithInit(er.sctx, funcName, retType, init, args...)
|
|
} else if er.disableFoldCounter > 0 {
|
|
ret, err = expression.NewFunctionBase(er.sctx, funcName, retType, args...)
|
|
} else if er.tryFoldCounter > 0 {
|
|
ret, err = expression.NewFunctionTryFold(er.sctx, funcName, retType, args...)
|
|
} else {
|
|
ret, err = expression.NewFunction(er.sctx, funcName, retType, args...)
|
|
}
|
|
if err != nil {
|
|
return
|
|
}
|
|
if scalarFunc, ok := ret.(*expression.ScalarFunction); ok {
|
|
if er.planCtx != nil && er.planCtx.builder != nil {
|
|
er.planCtx.builder.ctx.BuiltinFunctionUsageInc(scalarFunc.Function.PbCode().String())
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// newFunction is being redirected to newFunctionWithInit.
|
|
func (er *expressionRewriter) newFunction(funcName string, retType *types.FieldType, args ...expression.Expression) (ret expression.Expression, err error) {
|
|
return er.newFunctionWithInit(funcName, retType, nil, args...)
|
|
}
|
|
|
|
func (*expressionRewriter) checkTimePrecision(ft *types.FieldType) error {
|
|
if ft.EvalType() == types.ETDuration && ft.GetDecimal() > types.MaxFsp {
|
|
return plannererrors.ErrTooBigPrecision.GenWithStackByArgs(ft.GetDecimal(), "CAST", types.MaxFsp)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (er *expressionRewriter) useCache() bool {
|
|
return er.sctx.IsUseCache()
|
|
}
|
|
|
|
func (er *expressionRewriter) rewriteUserVariable(v *ast.VariableExpr) {
|
|
stkLen := len(er.ctxStack)
|
|
name := strings.ToLower(v.Name)
|
|
evalCtx := er.sctx.GetEvalCtx()
|
|
if v.Value != nil {
|
|
if !evalCtx.GetOptionalPropSet().Contains(exprctx.OptPropSessionVars) {
|
|
er.err = errors.Errorf("rewriting user variable requires '%s' in evalCtx", exprctx.OptPropSessionVars.String())
|
|
return
|
|
}
|
|
|
|
sessionVars, err := expropt.SessionVarsPropReader{}.GetSessionVars(evalCtx)
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
|
|
intest.Assert(er.planCtx == nil || sessionVars == er.planCtx.builder.ctx.GetSessionVars())
|
|
|
|
tp := er.ctxStack[stkLen-1].GetType(er.sctx.GetEvalCtx())
|
|
er.ctxStack[stkLen-1], er.err = er.newFunction(ast.SetVar, tp,
|
|
expression.DatumToConstant(types.NewDatum(name), mysql.TypeString, 0),
|
|
er.ctxStack[stkLen-1])
|
|
er.ctxNameStk[stkLen-1] = types.EmptyName
|
|
// Store the field type of the variable into SessionVars.UserVarTypes.
|
|
// Normally we can infer the type from SessionVars.User, but we need SessionVars.UserVarTypes when
|
|
// GetVar has not been executed to fill the SessionVars.Users.
|
|
sessionVars.SetUserVarType(name, tp)
|
|
return
|
|
}
|
|
tp, ok := evalCtx.GetUserVarsReader().GetUserVarType(name)
|
|
if !ok {
|
|
tp = types.NewFieldType(mysql.TypeVarString)
|
|
tp.SetFlen(mysql.MaxFieldVarCharLength)
|
|
}
|
|
f, err := er.newFunction(ast.GetVar, tp, expression.DatumToConstant(types.NewStringDatum(name), mysql.TypeString, 0))
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
f.SetCoercibility(expression.CoercibilityImplicit)
|
|
er.ctxStackAppend(f, types.EmptyName)
|
|
}
|
|
|
|
func (er *expressionRewriter) rewriteSystemVariable(planCtx *exprRewriterPlanCtx, v *ast.VariableExpr) {
|
|
name := strings.ToLower(v.Name)
|
|
sessionVars := planCtx.builder.ctx.GetSessionVars()
|
|
sysVar := variable.GetSysVar(name)
|
|
if sysVar == nil {
|
|
er.err = variable.ErrUnknownSystemVar.FastGenByArgs(name)
|
|
if err := variable.CheckSysVarIsRemoved(name); err != nil {
|
|
// Removed vars still return an error, but we customize it from
|
|
// "unknown" to an explanation of why it is not supported.
|
|
// This is important so users at least know they had the name correct.
|
|
er.err = err
|
|
}
|
|
return
|
|
}
|
|
if sysVar.IsNoop && !vardef.EnableNoopVariables.Load() {
|
|
// The variable does nothing, append a warning to the statement output.
|
|
sessionVars.StmtCtx.AppendWarning(plannererrors.ErrGettingNoopVariable.FastGenByArgs(sysVar.Name))
|
|
}
|
|
if sem.IsEnabled() && sem.IsInvisibleSysVar(sysVar.Name) {
|
|
err := plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("RESTRICTED_VARIABLES_ADMIN")
|
|
planCtx.builder.visitInfo = appendDynamicVisitInfo(planCtx.builder.visitInfo, []string{"RESTRICTED_VARIABLES_ADMIN"}, false, err)
|
|
}
|
|
if v.ExplicitScope && !sysVar.HasNoneScope() {
|
|
if v.IsGlobal && !(sysVar.HasGlobalScope() || sysVar.HasInstanceScope()) {
|
|
er.err = variable.ErrIncorrectScope.GenWithStackByArgs(name, "SESSION")
|
|
return
|
|
}
|
|
if v.IsInstance && !sysVar.HasInstanceScope() {
|
|
er.err = variable.ErrIncorrectScope.GenWithStackByArgs(name, "SESSION or GLOBAL")
|
|
return
|
|
}
|
|
if !v.IsGlobal && !v.IsInstance {
|
|
if !sysVar.HasSessionScope() {
|
|
er.err = variable.ErrIncorrectScope.GenWithStackByArgs(name, "GLOBAL")
|
|
return
|
|
}
|
|
if sysVar.InternalSessionVariable {
|
|
er.err = variable.ErrUnknownSystemVar.GenWithStackByArgs(name)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
var val string
|
|
var err error
|
|
if sysVar.HasNoneScope() {
|
|
val = sysVar.Value
|
|
} else if v.IsGlobal || v.IsInstance {
|
|
val, err = sessionVars.GetGlobalSystemVar(er.ctx, name)
|
|
} else {
|
|
val, err = sessionVars.GetSessionOrGlobalSystemVar(er.ctx, name)
|
|
}
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
nativeVal, nativeType, nativeFlag := sysVar.GetNativeValType(val)
|
|
e := expression.DatumToConstant(nativeVal, nativeType, nativeFlag)
|
|
switch nativeType {
|
|
case mysql.TypeVarString:
|
|
charset, _ := sessionVars.GetSystemVar(vardef.CharacterSetConnection)
|
|
e.GetType(er.sctx.GetEvalCtx()).SetCharset(charset)
|
|
collate, _ := sessionVars.GetSystemVar(vardef.CollationConnection)
|
|
e.GetType(er.sctx.GetEvalCtx()).SetCollate(collate)
|
|
case mysql.TypeLong, mysql.TypeLonglong:
|
|
e.GetType(er.sctx.GetEvalCtx()).SetCharset(charset.CharsetBin)
|
|
e.GetType(er.sctx.GetEvalCtx()).SetCollate(charset.CollationBin)
|
|
default:
|
|
er.err = errors.Errorf("Not supported type(%x) in GetNativeValType() function", nativeType)
|
|
return
|
|
}
|
|
er.ctxStackAppend(e, types.EmptyName)
|
|
}
|
|
|
|
func (er *expressionRewriter) unaryOpToExpression(v *ast.UnaryOperationExpr) {
|
|
stkLen := len(er.ctxStack)
|
|
var op string
|
|
switch v.Op {
|
|
case opcode.Plus:
|
|
// expression (+ a) is equal to a
|
|
return
|
|
case opcode.Minus:
|
|
op = ast.UnaryMinus
|
|
case opcode.BitNeg:
|
|
op = ast.BitNeg
|
|
case opcode.Not, opcode.Not2:
|
|
op = ast.UnaryNot
|
|
default:
|
|
er.err = errors.Errorf("Unknown Unary Op %T", v.Op)
|
|
return
|
|
}
|
|
if expression.GetRowLen(er.ctxStack[stkLen-1]) != 1 {
|
|
er.err = expression.ErrOperandColumns.GenWithStackByArgs(1)
|
|
return
|
|
}
|
|
er.ctxStack[stkLen-1], er.err = er.newFunction(op, &v.Type, er.ctxStack[stkLen-1])
|
|
er.ctxNameStk[stkLen-1] = types.EmptyName
|
|
}
|
|
|
|
func (er *expressionRewriter) binaryOpToExpression(v *ast.BinaryOperationExpr) {
|
|
stkLen := len(er.ctxStack)
|
|
var function expression.Expression
|
|
switch v.Op {
|
|
case opcode.EQ, opcode.NE, opcode.NullEQ, opcode.GT, opcode.GE, opcode.LT, opcode.LE:
|
|
function, er.err = er.constructBinaryOpFunction(er.ctxStack[stkLen-2], er.ctxStack[stkLen-1],
|
|
v.Op.String())
|
|
default:
|
|
lLen := expression.GetRowLen(er.ctxStack[stkLen-2])
|
|
rLen := expression.GetRowLen(er.ctxStack[stkLen-1])
|
|
if lLen != 1 || rLen != 1 {
|
|
er.err = expression.ErrOperandColumns.GenWithStackByArgs(1)
|
|
return
|
|
}
|
|
function, er.err = er.newFunction(v.Op.String(), types.NewFieldType(mysql.TypeUnspecified), er.ctxStack[stkLen-2:]...)
|
|
}
|
|
if er.err != nil {
|
|
return
|
|
}
|
|
er.ctxStackPop(2)
|
|
er.ctxStackAppend(function, types.EmptyName)
|
|
}
|
|
|
|
func (er *expressionRewriter) notToExpression(hasNot bool, op string, tp *types.FieldType,
|
|
args ...expression.Expression) expression.Expression {
|
|
opFunc, err := er.newFunction(op, tp, args...)
|
|
if err != nil {
|
|
er.err = err
|
|
return nil
|
|
}
|
|
if !hasNot {
|
|
return opFunc
|
|
}
|
|
|
|
opFunc, err = er.newFunction(ast.UnaryNot, tp, opFunc)
|
|
if err != nil {
|
|
er.err = err
|
|
return nil
|
|
}
|
|
return opFunc
|
|
}
|
|
|
|
func (er *expressionRewriter) isNullToExpression(v *ast.IsNullExpr) {
|
|
stkLen := len(er.ctxStack)
|
|
if expression.GetRowLen(er.ctxStack[stkLen-1]) != 1 {
|
|
er.err = expression.ErrOperandColumns.GenWithStackByArgs(1)
|
|
return
|
|
}
|
|
function := er.notToExpression(v.Not, ast.IsNull, &v.Type, er.ctxStack[stkLen-1])
|
|
er.ctxStackPop(1)
|
|
er.ctxStackAppend(function, types.EmptyName)
|
|
}
|
|
|
|
func (er *expressionRewriter) positionToScalarFunc(planCtx *exprRewriterPlanCtx, v *ast.PositionExpr) {
|
|
intest.AssertNotNil(planCtx)
|
|
pos := v.N
|
|
str := strconv.Itoa(pos)
|
|
if v.P != nil {
|
|
stkLen := len(er.ctxStack)
|
|
val := er.ctxStack[stkLen-1]
|
|
intNum, isNull, err := expression.GetIntFromConstant(er.sctx.GetEvalCtx(), val)
|
|
str = "?"
|
|
if err == nil {
|
|
if isNull {
|
|
return
|
|
}
|
|
pos = intNum
|
|
er.ctxStackPop(1)
|
|
}
|
|
er.err = err
|
|
}
|
|
if er.err == nil && pos > 0 && pos <= er.schema.Len() && !er.schema.Columns[pos-1].IsHidden {
|
|
er.ctxStackAppend(er.schema.Columns[pos-1], er.names[pos-1])
|
|
} else {
|
|
er.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs(str, clauseMsg[planCtx.builder.curClause])
|
|
}
|
|
}
|
|
|
|
func (er *expressionRewriter) isTrueToScalarFunc(v *ast.IsTruthExpr) {
|
|
stkLen := len(er.ctxStack)
|
|
op := ast.IsTruthWithoutNull
|
|
if v.True == 0 {
|
|
op = ast.IsFalsity
|
|
}
|
|
if expression.GetRowLen(er.ctxStack[stkLen-1]) != 1 {
|
|
er.err = expression.ErrOperandColumns.GenWithStackByArgs(1)
|
|
return
|
|
}
|
|
function := er.notToExpression(v.Not, op, &v.Type, er.ctxStack[stkLen-1])
|
|
er.ctxStackPop(1)
|
|
er.ctxStackAppend(function, types.EmptyName)
|
|
}
|
|
|
|
// inToExpression converts in expression to a scalar function. The argument lLen means the length of in list.
|
|
// The argument not means if the expression is not in. The tp stands for the expression type, which is always bool.
|
|
// a in (b, c, d) will be rewritten as `(a = b) or (a = c) or (a = d)`.
|
|
func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.FieldType) {
|
|
stkLen := len(er.ctxStack)
|
|
l := expression.GetRowLen(er.ctxStack[stkLen-lLen-1])
|
|
for i := range lLen {
|
|
if l != expression.GetRowLen(er.ctxStack[stkLen-lLen+i]) {
|
|
er.err = expression.ErrOperandColumns.GenWithStackByArgs(l)
|
|
return
|
|
}
|
|
}
|
|
args := er.ctxStack[stkLen-lLen-1:]
|
|
leftFt := args[0].GetType(er.sctx.GetEvalCtx())
|
|
leftEt, leftIsNull := leftFt.EvalType(), leftFt.GetType() == mysql.TypeNull
|
|
if leftIsNull {
|
|
er.ctxStackPop(lLen + 1)
|
|
er.ctxStackAppend(expression.NewNull(), types.EmptyName)
|
|
return
|
|
}
|
|
if leftEt == types.ETInt {
|
|
for i := 1; i < len(args); i++ {
|
|
if c, ok := args[i].(*expression.Constant); ok {
|
|
var isExceptional bool
|
|
if expression.MaybeOverOptimized4PlanCache(er.sctx, c) {
|
|
if c.GetType(er.sctx.GetEvalCtx()).EvalType() == types.ETInt {
|
|
continue // no need to refine it
|
|
}
|
|
er.sctx.SetSkipPlanCache(fmt.Sprintf("'%v' may be converted to INT", c.StringWithCtx(er.sctx.GetEvalCtx(), errors.RedactLogDisable)))
|
|
if err := expression.RemoveMutableConst(er.sctx, c); err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
}
|
|
args[i], isExceptional = expression.RefineComparedConstant(er.sctx, *leftFt, c, opcode.EQ)
|
|
if isExceptional {
|
|
args[i] = c
|
|
}
|
|
}
|
|
}
|
|
}
|
|
allSameType := true
|
|
for _, arg := range args[1:] {
|
|
if arg.GetType(er.sctx.GetEvalCtx()).GetType() != mysql.TypeNull && expression.GetAccurateCmpType(er.sctx.GetEvalCtx(), args[0], arg) != leftEt {
|
|
allSameType = false
|
|
break
|
|
}
|
|
}
|
|
var function expression.Expression
|
|
if allSameType && l == 1 && lLen > 1 {
|
|
function = er.notToExpression(not, ast.In, tp, er.ctxStack[stkLen-lLen-1:]...)
|
|
} else {
|
|
// If we rewrite IN to EQ, we need to decide what's the collation EQ uses.
|
|
coll := er.deriveCollationForIn(l, lLen, args)
|
|
if er.err != nil {
|
|
return
|
|
}
|
|
er.castCollationForIn(l, lLen, stkLen, coll)
|
|
eqFunctions := make([]expression.Expression, 0, lLen)
|
|
for i := stkLen - lLen; i < stkLen; i++ {
|
|
expr, err := er.constructBinaryOpFunction(args[0], er.ctxStack[i], ast.EQ)
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
eqFunctions = append(eqFunctions, expr)
|
|
}
|
|
function = expression.ComposeDNFCondition(er.sctx, eqFunctions...)
|
|
if not {
|
|
var err error
|
|
function, err = er.newFunction(ast.UnaryNot, tp, function)
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
}
|
|
}
|
|
er.ctxStackPop(lLen + 1)
|
|
er.ctxStackAppend(function, types.EmptyName)
|
|
}
|
|
|
|
// deriveCollationForIn derives collation for in expression.
|
|
// We don't handle the cases if the element is a tuple, such as (a, b, c) in ((x1, y1, z1), (x2, y2, z2)).
|
|
func (er *expressionRewriter) deriveCollationForIn(colLen int, _ int, args []expression.Expression) *expression.ExprCollation {
|
|
if colLen == 1 {
|
|
// a in (x, y, z) => coll[0]
|
|
coll2, err := expression.CheckAndDeriveCollationFromExprs(er.sctx, "IN", types.ETInt, args...)
|
|
er.err = err
|
|
if er.err != nil {
|
|
return nil
|
|
}
|
|
return coll2
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// castCollationForIn casts collation info for arguments in the `in clause` to make sure the used collation is correct after we
|
|
// rewrite it to equal expression.
|
|
func (er *expressionRewriter) castCollationForIn(colLen int, elemCnt int, stkLen int, coll *expression.ExprCollation) {
|
|
// We don't handle the cases if the element is a tuple, such as (a, b, c) in ((x1, y1, z1), (x2, y2, z2)).
|
|
if colLen != 1 {
|
|
return
|
|
}
|
|
if !collate.NewCollationEnabled() {
|
|
// See https://github.com/pingcap/tidb/issues/52772
|
|
// This function will apply CoercibilityExplicit to the casted expression, but some checks(during ColumnSubstituteImpl) is missed when the new
|
|
// collation is disabled, then lead to panic.
|
|
// To work around this issue, we can skip the function, it should be good since the collation is disabled.
|
|
return
|
|
}
|
|
for i := stkLen - elemCnt; i < stkLen; i++ {
|
|
// todo: consider refining the code and reusing expression.BuildCollationFunction here
|
|
if er.ctxStack[i].GetType(er.sctx.GetEvalCtx()).EvalType() == types.ETString {
|
|
rowFunc, ok := er.ctxStack[i].(*expression.ScalarFunction)
|
|
if ok && rowFunc.FuncName.String() == ast.RowFunc {
|
|
continue
|
|
}
|
|
// Don't convert it if it's charset is binary. So that we don't convert 0x12 to a string.
|
|
if er.ctxStack[i].GetType(er.sctx.GetEvalCtx()).GetCollate() == coll.Collation {
|
|
continue
|
|
}
|
|
tp := er.ctxStack[i].GetType(er.sctx.GetEvalCtx()).Clone()
|
|
if er.ctxStack[i].GetType(er.sctx.GetEvalCtx()).Hybrid() {
|
|
if !(expression.GetAccurateCmpType(er.sctx.GetEvalCtx(), er.ctxStack[stkLen-elemCnt-1], er.ctxStack[i]) == types.ETString) {
|
|
continue
|
|
}
|
|
tp = types.NewFieldType(mysql.TypeVarString)
|
|
} else if coll.Charset == charset.CharsetBin {
|
|
// When cast character string to binary string, if we still use fixed length representation,
|
|
// then 0 padding will be used, which can affect later execution.
|
|
// e.g. https://github.com/pingcap/tidb/pull/35053#pullrequestreview-1008757770 gives an unexpected case.
|
|
// On the other hand, we can not directly return origin expr back,
|
|
// since we need binary collation to do string comparison later.
|
|
// Here we use VarString type of cast, i.e `cast(a as binary)`, to avoid this problem.
|
|
tp.SetType(mysql.TypeVarString)
|
|
}
|
|
tp.SetCharset(coll.Charset)
|
|
tp.SetCollate(coll.Collation)
|
|
er.ctxStack[i] = expression.BuildCastFunction(er.sctx, er.ctxStack[i], tp)
|
|
er.ctxStack[i].SetCoercibility(expression.CoercibilityExplicit)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (er *expressionRewriter) caseToExpression(v *ast.CaseExpr) {
|
|
stkLen := len(er.ctxStack)
|
|
argsLen := 2 * len(v.WhenClauses)
|
|
if v.ElseClause != nil {
|
|
argsLen++
|
|
}
|
|
er.err = expression.CheckArgsNotMultiColumnRow(er.ctxStack[stkLen-argsLen:]...)
|
|
if er.err != nil {
|
|
return
|
|
}
|
|
|
|
// value -> ctxStack[stkLen-argsLen-1]
|
|
// when clause(condition, result) -> ctxStack[stkLen-argsLen:stkLen-1];
|
|
// else clause -> ctxStack[stkLen-1]
|
|
var args []expression.Expression
|
|
if v.Value != nil {
|
|
// args: eq scalar func(args: value, condition1), result1,
|
|
// eq scalar func(args: value, condition2), result2,
|
|
// ...
|
|
// else clause
|
|
value := er.ctxStack[stkLen-argsLen-1]
|
|
args = make([]expression.Expression, 0, argsLen)
|
|
for i := stkLen - argsLen; i < stkLen-1; i += 2 {
|
|
arg, err := er.newFunction(ast.EQ, types.NewFieldType(mysql.TypeTiny), value, er.ctxStack[i])
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
args = append(args, arg)
|
|
args = append(args, er.ctxStack[i+1])
|
|
}
|
|
if v.ElseClause != nil {
|
|
args = append(args, er.ctxStack[stkLen-1])
|
|
}
|
|
argsLen++ // for trimming the value element later
|
|
} else {
|
|
// args: condition1, result1,
|
|
// condition2, result2,
|
|
// ...
|
|
// else clause
|
|
args = er.ctxStack[stkLen-argsLen:]
|
|
}
|
|
function, err := er.newFunction(ast.Case, &v.Type, args...)
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
er.ctxStackPop(argsLen)
|
|
er.ctxStackAppend(function, types.EmptyName)
|
|
}
|
|
|
|
func (er *expressionRewriter) patternLikeOrIlikeToExpression(v *ast.PatternLikeOrIlikeExpr) {
|
|
l := len(er.ctxStack)
|
|
er.err = expression.CheckArgsNotMultiColumnRow(er.ctxStack[l-2:]...)
|
|
if er.err != nil {
|
|
return
|
|
}
|
|
|
|
char, col := er.sctx.GetCharsetInfo()
|
|
var function expression.Expression
|
|
fieldType := &types.FieldType{}
|
|
isPatternExactMatch := false
|
|
// Treat predicate 'like' or 'ilike' the same way as predicate '=' when it is an exact match and new collation is not enabled.
|
|
if patExpression, ok := er.ctxStack[l-1].(*expression.Constant); ok && !collate.NewCollationEnabled() {
|
|
patString, isNull, err := patExpression.EvalString(er.sctx.GetEvalCtx(), chunk.Row{})
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
if !isNull {
|
|
patValue, patTypes := stringutil.CompilePattern(patString, v.Escape)
|
|
if stringutil.IsExactMatch(patTypes) && er.ctxStack[l-2].GetType(er.sctx.GetEvalCtx()).EvalType() == types.ETString {
|
|
op := ast.EQ
|
|
if v.Not {
|
|
op = ast.NE
|
|
}
|
|
types.DefaultTypeForValue(string(patValue), fieldType, char, col)
|
|
function, er.err = er.constructBinaryOpFunction(er.ctxStack[l-2],
|
|
&expression.Constant{Value: types.NewStringDatum(string(patValue)), RetType: fieldType},
|
|
op)
|
|
isPatternExactMatch = true
|
|
}
|
|
}
|
|
}
|
|
if !isPatternExactMatch {
|
|
funcName := ast.Like
|
|
if !v.IsLike {
|
|
funcName = ast.Ilike
|
|
}
|
|
types.DefaultTypeForValue(int(v.Escape), fieldType, char, col)
|
|
function = er.notToExpression(v.Not, funcName, &v.Type,
|
|
er.ctxStack[l-2], er.ctxStack[l-1], &expression.Constant{Value: types.NewIntDatum(int64(v.Escape)), RetType: fieldType})
|
|
}
|
|
|
|
er.ctxStackPop(2)
|
|
er.ctxStackAppend(function, types.EmptyName)
|
|
}
|
|
|
|
func (er *expressionRewriter) regexpToScalarFunc(v *ast.PatternRegexpExpr) {
|
|
l := len(er.ctxStack)
|
|
er.err = expression.CheckArgsNotMultiColumnRow(er.ctxStack[l-2:]...)
|
|
if er.err != nil {
|
|
return
|
|
}
|
|
function := er.notToExpression(v.Not, ast.Regexp, &v.Type, er.ctxStack[l-2], er.ctxStack[l-1])
|
|
er.ctxStackPop(2)
|
|
er.ctxStackAppend(function, types.EmptyName)
|
|
}
|
|
|
|
func (er *expressionRewriter) rowToScalarFunc(v *ast.RowExpr) {
|
|
stkLen := len(er.ctxStack)
|
|
length := len(v.Values)
|
|
rows := make([]expression.Expression, 0, length)
|
|
for i := stkLen - length; i < stkLen; i++ {
|
|
rows = append(rows, er.ctxStack[i])
|
|
}
|
|
er.ctxStackPop(length)
|
|
function, err := er.newFunction(ast.RowFunc, rows[0].GetType(er.sctx.GetEvalCtx()), rows...)
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
er.ctxStackAppend(function, types.EmptyName)
|
|
}
|
|
|
|
func (er *expressionRewriter) wrapExpWithCast() (expr, lexp, rexp expression.Expression) {
|
|
stkLen := len(er.ctxStack)
|
|
expr, lexp, rexp = er.ctxStack[stkLen-3], er.ctxStack[stkLen-2], er.ctxStack[stkLen-1]
|
|
var castFunc func(expression.BuildContext, expression.Expression) expression.Expression
|
|
switch expression.ResolveType4Between(er.sctx.GetEvalCtx(), [3]expression.Expression{expr, lexp, rexp}) {
|
|
case types.ETInt:
|
|
expr = expression.WrapWithCastAsInt(er.sctx, expr, nil)
|
|
lexp = expression.WrapWithCastAsInt(er.sctx, lexp, nil)
|
|
rexp = expression.WrapWithCastAsInt(er.sctx, rexp, nil)
|
|
return
|
|
case types.ETReal:
|
|
castFunc = expression.WrapWithCastAsReal
|
|
case types.ETDecimal:
|
|
castFunc = expression.WrapWithCastAsDecimal
|
|
case types.ETString:
|
|
castFunc = func(ctx expression.BuildContext, e expression.Expression) expression.Expression {
|
|
// string kind expression do not need cast
|
|
if e.GetType(er.sctx.GetEvalCtx()).EvalType().IsStringKind() {
|
|
return e
|
|
}
|
|
return expression.WrapWithCastAsString(ctx, e)
|
|
}
|
|
case types.ETDuration:
|
|
expr = expression.WrapWithCastAsTime(er.sctx, expr, types.NewFieldType(mysql.TypeDuration))
|
|
lexp = expression.WrapWithCastAsTime(er.sctx, lexp, types.NewFieldType(mysql.TypeDuration))
|
|
rexp = expression.WrapWithCastAsTime(er.sctx, rexp, types.NewFieldType(mysql.TypeDuration))
|
|
return
|
|
case types.ETDatetime:
|
|
expr = expression.WrapWithCastAsTime(er.sctx, expr, types.NewFieldType(mysql.TypeDatetime))
|
|
lexp = expression.WrapWithCastAsTime(er.sctx, lexp, types.NewFieldType(mysql.TypeDatetime))
|
|
rexp = expression.WrapWithCastAsTime(er.sctx, rexp, types.NewFieldType(mysql.TypeDatetime))
|
|
return
|
|
default:
|
|
return
|
|
}
|
|
|
|
expr = castFunc(er.sctx, expr)
|
|
lexp = castFunc(er.sctx, lexp)
|
|
rexp = castFunc(er.sctx, rexp)
|
|
return
|
|
}
|
|
|
|
func (er *expressionRewriter) betweenToExpression(v *ast.BetweenExpr) {
|
|
stkLen := len(er.ctxStack)
|
|
er.err = expression.CheckArgsNotMultiColumnRow(er.ctxStack[stkLen-3:]...)
|
|
if er.err != nil {
|
|
return
|
|
}
|
|
|
|
expr, lexp, rexp := er.wrapExpWithCast()
|
|
|
|
coll, err := expression.CheckAndDeriveCollationFromExprs(er.sctx, "BETWEEN", types.ETInt, expr, lexp, rexp)
|
|
er.err = err
|
|
if er.err != nil {
|
|
return
|
|
}
|
|
|
|
// Handle enum or set. We need to know their real type to decide whether to cast them.
|
|
lt := expression.GetAccurateCmpType(er.sctx.GetEvalCtx(), expr, lexp)
|
|
rt := expression.GetAccurateCmpType(er.sctx.GetEvalCtx(), expr, rexp)
|
|
enumOrSetRealTypeIsStr := lt != types.ETInt && rt != types.ETInt
|
|
|
|
expr = expression.BuildCastCollationFunction(er.sctx, expr, coll, enumOrSetRealTypeIsStr)
|
|
lexp = expression.BuildCastCollationFunction(er.sctx, lexp, coll, enumOrSetRealTypeIsStr)
|
|
rexp = expression.BuildCastCollationFunction(er.sctx, rexp, coll, enumOrSetRealTypeIsStr)
|
|
|
|
var l, r expression.Expression
|
|
l, er.err = expression.NewFunction(er.sctx, ast.GE, &v.Type, expr, lexp)
|
|
if er.err != nil {
|
|
return
|
|
}
|
|
r, er.err = expression.NewFunction(er.sctx, ast.LE, &v.Type, expr, rexp)
|
|
if er.err != nil {
|
|
return
|
|
}
|
|
function, err := er.newFunction(ast.LogicAnd, &v.Type, l, r)
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
if v.Not {
|
|
function, err = er.newFunction(ast.UnaryNot, &v.Type, function)
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
}
|
|
er.ctxStackPop(3)
|
|
er.ctxStackAppend(function, types.EmptyName)
|
|
}
|
|
|
|
// rewriteFuncCall handles a FuncCallExpr and generates a customized function.
|
|
// It should return true if for the given FuncCallExpr a rewrite is performed so that original behavior is skipped.
|
|
// Otherwise it should return false to indicate (the caller) that original behavior needs to be performed.
|
|
func (er *expressionRewriter) rewriteFuncCall(v *ast.FuncCallExpr) bool {
|
|
switch v.FnName.L {
|
|
// when column is not null, ifnull on such column can be optimized to a cast.
|
|
case ast.Ifnull:
|
|
if len(v.Args) != 2 {
|
|
er.err = expression.ErrIncorrectParameterCount.GenWithStackByArgs(v.FnName.O)
|
|
return true
|
|
}
|
|
stackLen := len(er.ctxStack)
|
|
lhs := er.ctxStack[stackLen-2]
|
|
rhs := er.ctxStack[stackLen-1]
|
|
col, isColumn := lhs.(*expression.Column)
|
|
var isEnumSet bool
|
|
if lhs.GetType(er.sctx.GetEvalCtx()).GetType() == mysql.TypeEnum || lhs.GetType(er.sctx.GetEvalCtx()).GetType() == mysql.TypeSet {
|
|
isEnumSet = true
|
|
}
|
|
// if expr1 is a column with not null flag, then we can optimize it as a cast.
|
|
if isColumn && !isEnumSet && mysql.HasNotNullFlag(col.RetType.GetFlag()) {
|
|
retTp, err := expression.InferType4ControlFuncs(er.sctx, ast.Ifnull, lhs, rhs)
|
|
if err != nil {
|
|
er.err = err
|
|
return true
|
|
}
|
|
retTp.AddFlag((lhs.GetType(er.sctx.GetEvalCtx()).GetFlag() & mysql.NotNullFlag) | (rhs.GetType(er.sctx.GetEvalCtx()).GetFlag() & mysql.NotNullFlag))
|
|
if lhs.GetType(er.sctx.GetEvalCtx()).GetType() == mysql.TypeNull && rhs.GetType(er.sctx.GetEvalCtx()).GetType() == mysql.TypeNull {
|
|
retTp.SetType(mysql.TypeNull)
|
|
retTp.SetFlen(0)
|
|
retTp.SetDecimal(0)
|
|
types.SetBinChsClnFlag(retTp)
|
|
}
|
|
er.ctxStackPop(len(v.Args))
|
|
casted := expression.BuildCastFunction(er.sctx, lhs, retTp)
|
|
er.ctxStackAppend(casted, types.EmptyName)
|
|
return true
|
|
}
|
|
|
|
return false
|
|
case ast.Nullif:
|
|
if len(v.Args) != 2 {
|
|
er.err = expression.ErrIncorrectParameterCount.GenWithStackByArgs(v.FnName.O)
|
|
return true
|
|
}
|
|
stackLen := len(er.ctxStack)
|
|
param1 := er.ctxStack[stackLen-2]
|
|
param2 := er.ctxStack[stackLen-1]
|
|
// param1 = param2
|
|
funcCompare, err := er.constructBinaryOpFunction(param1, param2, ast.EQ)
|
|
if err != nil {
|
|
er.err = err
|
|
return true
|
|
}
|
|
// NULL
|
|
nullTp := types.NewFieldType(mysql.TypeNull)
|
|
flen, decimal := mysql.GetDefaultFieldLengthAndDecimal(mysql.TypeNull)
|
|
nullTp.SetFlen(flen)
|
|
nullTp.SetDecimal(decimal)
|
|
paramNull := &expression.Constant{
|
|
Value: types.NewDatum(nil),
|
|
RetType: nullTp,
|
|
}
|
|
// if(param1 = param2, NULL, param1)
|
|
funcIf, err := er.newFunction(ast.If, &v.Type, funcCompare, paramNull, param1)
|
|
if err != nil {
|
|
er.err = err
|
|
return true
|
|
}
|
|
er.ctxStackPop(len(v.Args))
|
|
er.ctxStackAppend(funcIf, types.EmptyName)
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func (er *expressionRewriter) funcCallToExpressionWithPlanCtx(planCtx *exprRewriterPlanCtx, v *ast.FuncCallExpr) {
|
|
stackLen := len(er.ctxStack)
|
|
args := er.ctxStack[stackLen-len(v.Args):]
|
|
er.err = expression.CheckArgsNotMultiColumnRow(args...)
|
|
if er.err != nil {
|
|
return
|
|
}
|
|
|
|
var function expression.Expression
|
|
er.ctxStackPop(len(v.Args))
|
|
switch v.FnName.L {
|
|
case ast.Grouping:
|
|
// grouping function should fetch the underlying grouping-sets meta and rewrite the args here.
|
|
// eg: grouping(a) actually is try to find in which grouping-set that the column 'a' is remained,
|
|
// collecting those gid as a collection and filling it into the grouping function meta. Besides,
|
|
// the first arg of grouping function should be rewritten as gid column defined/passed by Expand
|
|
// from the bottom up.
|
|
intest.AssertNotNil(planCtx)
|
|
if planCtx.rollExpand == nil {
|
|
er.err = plannererrors.ErrInvalidGroupFuncUse
|
|
er.ctxStackAppend(nil, types.EmptyName)
|
|
} else {
|
|
// whether there is some duplicate grouping sets, gpos is only be used in shuffle keys and group keys
|
|
// rather than grouping function.
|
|
// eg: rollup(a,a,b), the decided grouping sets are {a,a,b},{a,a,null},{a,null,null},{null,null,null}
|
|
// for the second and third grouping set: {a,a,null} and {a,null,null}, a here is the col ref of original
|
|
// column `a`. So from the static layer, this two grouping set are equivalent, we don't need to copy col
|
|
// `a double times at the every beginning and resort to gpos to distinguish them.
|
|
// {col-a, col-b, gid, gpos}
|
|
// {a, b, 0, 1}, {a, null, 1, 2}, {a, null, 1, 3}, {null, null, 2, 4}
|
|
// grouping function still only need to care about gid is enough, gpos what group and shuffle keys cared.
|
|
if len(args) > 64 {
|
|
er.err = plannererrors.ErrInvalidNumberOfArgs.GenWithStackByArgs("GROUPING", 64)
|
|
er.ctxStackAppend(nil, types.EmptyName)
|
|
return
|
|
}
|
|
// resolve grouping args in group by items or not.
|
|
resolvedCols, err := planCtx.rollExpand.ResolveGroupingFuncArgsInGroupBy(args)
|
|
if err != nil {
|
|
er.err = err
|
|
er.ctxStackAppend(nil, types.EmptyName)
|
|
return
|
|
}
|
|
newArg := planCtx.rollExpand.GID.Clone()
|
|
init := func(groupingFunc *expression.ScalarFunction) (*expression.ScalarFunction, error) {
|
|
err = groupingFunc.Function.(*expression.BuiltinGroupingImplSig).SetMetadata(planCtx.rollExpand.GroupingMode, planCtx.rollExpand.GenerateGroupingMarks(resolvedCols))
|
|
return groupingFunc, err
|
|
}
|
|
function, er.err = er.newFunctionWithInit(v.FnName.L, &v.Type, init, newArg)
|
|
er.ctxStackAppend(function, types.EmptyName)
|
|
}
|
|
default:
|
|
er.err = errors.Errorf("invalid function: %s", v.FnName.L)
|
|
er.ctxStackAppend(nil, types.EmptyName)
|
|
}
|
|
}
|
|
|
|
func (er *expressionRewriter) funcCallToExpression(v *ast.FuncCallExpr) {
|
|
stackLen := len(er.ctxStack)
|
|
args := er.ctxStack[stackLen-len(v.Args):]
|
|
er.err = expression.CheckArgsNotMultiColumnRow(args...)
|
|
if er.err != nil {
|
|
return
|
|
}
|
|
|
|
if er.rewriteFuncCall(v) {
|
|
return
|
|
}
|
|
|
|
var function expression.Expression
|
|
er.ctxStackPop(len(v.Args))
|
|
if ok := expression.IsDeferredFunctions(er.sctx, v.FnName.L); er.useCache() && ok {
|
|
// When the expression is unix_timestamp and the number of argument is not zero,
|
|
// we deal with it as normal expression.
|
|
if v.FnName.L == ast.UnixTimestamp && len(v.Args) != 0 {
|
|
function, er.err = er.newFunction(v.FnName.L, &v.Type, args...)
|
|
er.ctxStackAppend(function, types.EmptyName)
|
|
} else {
|
|
function, er.err = expression.NewFunctionBase(er.sctx, v.FnName.L, &v.Type, args...)
|
|
c := &expression.Constant{Value: types.NewDatum(nil), RetType: function.GetType(er.sctx.GetEvalCtx()).Clone(), DeferredExpr: function}
|
|
er.ctxStackAppend(c, types.EmptyName)
|
|
}
|
|
} else {
|
|
function, er.err = er.newFunction(v.FnName.L, &v.Type, args...)
|
|
er.ctxStackAppend(function, types.EmptyName)
|
|
}
|
|
}
|
|
|
|
// Now TableName in expression only used by sequence function like nextval(seq).
|
|
// The function arg should be evaluated as a table name rather than normal column name like mysql does.
|
|
func (er *expressionRewriter) toTable(v *ast.TableName) {
|
|
fullName := v.Name.L
|
|
if len(v.Schema.L) != 0 {
|
|
fullName = v.Schema.L + "." + fullName
|
|
}
|
|
val := &expression.Constant{
|
|
Value: types.NewDatum(fullName),
|
|
RetType: types.NewFieldType(mysql.TypeString),
|
|
}
|
|
er.ctxStackAppend(val, types.EmptyName)
|
|
}
|
|
|
|
func (er *expressionRewriter) toParamMarker(v *driver.ParamMarkerExpr) {
|
|
var value *expression.Constant
|
|
value, er.err = expression.ParamMarkerExpression(er.sctx, v, false)
|
|
if er.err != nil {
|
|
return
|
|
}
|
|
initConstantRepertoire(er.sctx.GetEvalCtx(), value)
|
|
er.adjustUTF8MB4Collation(value.RetType)
|
|
if er.err != nil {
|
|
return
|
|
}
|
|
er.ctxStackAppend(value, types.EmptyName)
|
|
}
|
|
|
|
func (er *expressionRewriter) clause() clauseCode {
|
|
if er.planCtx != nil {
|
|
return er.planCtx.builder.curClause
|
|
}
|
|
return expressionClause
|
|
}
|
|
|
|
func (er *expressionRewriter) toColumn(v *ast.ColumnName) {
|
|
idx, err := expression.FindFieldName(er.names, v)
|
|
if err != nil {
|
|
er.err = plannererrors.ErrAmbiguous.GenWithStackByArgs(v.Name, clauseMsg[fieldList])
|
|
return
|
|
}
|
|
if idx >= 0 {
|
|
column := er.schema.Columns[idx]
|
|
if column.IsHidden {
|
|
er.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs(v.Name, clauseMsg[er.clause()])
|
|
return
|
|
}
|
|
er.ctxStackAppend(column, er.names[idx])
|
|
return
|
|
} else if er.planCtx == nil && er.sourceTable != nil &&
|
|
(v.Table.L == "" || er.sourceTable.Name.L == v.Table.L) {
|
|
colInfo := er.sourceTable.FindPublicColumnByName(v.Name.L)
|
|
if colInfo == nil || colInfo.Hidden {
|
|
er.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs(v.Name, clauseMsg[er.clause()])
|
|
return
|
|
}
|
|
er.ctxStackAppend(&expression.Column{
|
|
RetType: &colInfo.FieldType,
|
|
ID: colInfo.ID,
|
|
UniqueID: colInfo.ID,
|
|
OrigName: fmt.Sprintf("%s.%s", er.sourceTable.Name.L, colInfo.Name.L),
|
|
}, &types.FieldName{ColName: v.Name})
|
|
return
|
|
}
|
|
|
|
planCtx := er.planCtx
|
|
if planCtx == nil {
|
|
er.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs(v.String(), clauseMsg[er.clause()])
|
|
return
|
|
}
|
|
|
|
col, name, err := findFieldNameFromNaturalUsingJoin(planCtx.plan, v)
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
} else if col != nil {
|
|
er.ctxStackAppend(col, name)
|
|
return
|
|
}
|
|
for i := len(planCtx.builder.outerSchemas) - 1; i >= 0; i-- {
|
|
outerSchema, outerName := planCtx.builder.outerSchemas[i], planCtx.builder.outerNames[i]
|
|
idx, err = expression.FindFieldName(outerName, v)
|
|
if idx >= 0 {
|
|
column := outerSchema.Columns[idx]
|
|
er.ctxStackAppend(&expression.CorrelatedColumn{Column: *column, Data: new(types.Datum)}, outerName[idx])
|
|
return
|
|
}
|
|
if err != nil {
|
|
er.err = plannererrors.ErrAmbiguous.GenWithStackByArgs(v.Name, clauseMsg[fieldList])
|
|
return
|
|
}
|
|
}
|
|
if _, ok := planCtx.plan.(*logicalop.LogicalUnionAll); ok && v.Table.O != "" {
|
|
er.err = plannererrors.ErrTablenameNotAllowedHere.GenWithStackByArgs(v.Table.O, "SELECT", clauseMsg[planCtx.builder.curClause])
|
|
return
|
|
}
|
|
if planCtx.builder.curClause == globalOrderByClause {
|
|
planCtx.builder.curClause = orderByClause
|
|
}
|
|
er.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs(v.String(), clauseMsg[planCtx.builder.curClause])
|
|
}
|
|
|
|
func findFieldNameFromNaturalUsingJoin(p base.LogicalPlan, v *ast.ColumnName) (col *expression.Column, name *types.FieldName, err error) {
|
|
switch x := p.(type) {
|
|
case *logicalop.LogicalLimit, *logicalop.LogicalSelection, *logicalop.LogicalTopN, *logicalop.LogicalSort, *logicalop.LogicalMaxOneRow:
|
|
return findFieldNameFromNaturalUsingJoin(p.Children()[0], v)
|
|
case *logicalop.LogicalJoin:
|
|
if x.FullSchema != nil {
|
|
idx, err := expression.FindFieldName(x.FullNames, v)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
if idx >= 0 {
|
|
return x.FullSchema.Columns[idx], x.FullNames[idx], nil
|
|
}
|
|
}
|
|
}
|
|
return nil, nil, nil
|
|
}
|
|
|
|
func (er *expressionRewriter) evalDefaultExprForTable(v *ast.DefaultExpr, tbl *model.TableInfo) {
|
|
idx, err := expression.FindFieldName(er.names, v.Name)
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
if idx < 0 {
|
|
er.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs(v.Name.OrigColName(), "field list")
|
|
return
|
|
}
|
|
name := er.names[idx]
|
|
er.evalFieldDefaultValue(name, tbl)
|
|
}
|
|
|
|
func (er *expressionRewriter) evalDefaultExprWithPlanCtx(planCtx *exprRewriterPlanCtx, v *ast.DefaultExpr) {
|
|
intest.AssertNotNil(planCtx)
|
|
var name *types.FieldName
|
|
// Here we will find the corresponding column for default function. At the same time, we need to consider the issue
|
|
// of subquery and name space.
|
|
// For example, we have two tables t1(a int default 1, b int) and t2(a int default -1, c int). Consider the following SQL:
|
|
// select a from t1 where a > (select default(a) from t2)
|
|
// Refer to the behavior of MySQL, we need to find column a in table t2. If table t2 does not have column a, then find it
|
|
// in table t1. If there are none, return an error message.
|
|
// Based on the above description, we need to look in er.b.allNames from back to front.
|
|
for i := len(planCtx.builder.allNames) - 1; i >= 0; i-- {
|
|
idx, err := expression.FindFieldName(planCtx.builder.allNames[i], v.Name)
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
if idx >= 0 {
|
|
name = planCtx.builder.allNames[i][idx]
|
|
break
|
|
}
|
|
}
|
|
if name == nil {
|
|
idx, err := expression.FindFieldName(er.names, v.Name)
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
if idx < 0 {
|
|
er.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs(v.Name.OrigColName(), "field list")
|
|
return
|
|
}
|
|
name = er.names[idx]
|
|
}
|
|
|
|
dbName := name.DBName
|
|
if dbName.O == "" {
|
|
// if database name is not specified, use current database name
|
|
dbName = ast.NewCIStr(planCtx.builder.ctx.GetSessionVars().CurrentDB)
|
|
}
|
|
if name.OrigTblName.O == "" {
|
|
// column is evaluated by some expressions, for example:
|
|
// `select default(c) from (select (a+1) as c from t) as t0`
|
|
// in such case, a 'no default' error is returned
|
|
er.err = table.ErrNoDefaultValue.GenWithStackByArgs(name.ColName)
|
|
return
|
|
}
|
|
var tbl table.Table
|
|
tbl, er.err = planCtx.builder.is.TableByName(context.Background(), dbName, name.OrigTblName)
|
|
if er.err != nil {
|
|
return
|
|
}
|
|
er.evalFieldDefaultValue(name, tbl.Meta())
|
|
}
|
|
|
|
func (er *expressionRewriter) evalFieldDefaultValue(field *types.FieldName, tblInfo *model.TableInfo) {
|
|
colName := field.OrigColName.L
|
|
if colName == "" {
|
|
// in some cases, OrigColName is empty, use ColName instead
|
|
colName = field.ColName.L
|
|
}
|
|
col := tblInfo.FindPublicColumnByName(colName)
|
|
if col == nil {
|
|
er.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs(colName, "field_list")
|
|
return
|
|
}
|
|
isCurrentTimestamp := hasCurrentDatetimeDefault(col)
|
|
var val *expression.Constant
|
|
switch {
|
|
case isCurrentTimestamp && (col.GetType() == mysql.TypeDatetime || col.GetType() == mysql.TypeTimestamp):
|
|
t, err := expression.GetTimeValue(er.sctx, ast.CurrentTimestamp, col.GetType(), col.GetDecimal(), nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
val = &expression.Constant{
|
|
Value: t,
|
|
RetType: types.NewFieldType(col.GetType()),
|
|
}
|
|
default:
|
|
// for other columns, just use what it is
|
|
d, err := table.GetColDefaultValue(er.sctx, col)
|
|
if err != nil {
|
|
er.err = err
|
|
return
|
|
}
|
|
val = &expression.Constant{
|
|
Value: d,
|
|
RetType: col.FieldType.Clone(),
|
|
}
|
|
}
|
|
if er.err != nil {
|
|
return
|
|
}
|
|
er.ctxStackAppend(val, types.EmptyName)
|
|
}
|
|
|
|
// hasCurrentDatetimeDefault checks if column has current_timestamp default value
|
|
func hasCurrentDatetimeDefault(col *model.ColumnInfo) bool {
|
|
x, ok := col.DefaultValue.(string)
|
|
if !ok {
|
|
return false
|
|
}
|
|
return strings.ToLower(x) == ast.CurrentTimestamp
|
|
}
|
|
|
|
// hasLimit checks if the plan already contains a LIMIT operator
|
|
func hasLimit(plan base.LogicalPlan) bool {
|
|
if plan == nil {
|
|
return false
|
|
}
|
|
// Check if this is a LogicalLimit
|
|
if _, ok := plan.(*logicalop.LogicalLimit); ok {
|
|
return true
|
|
}
|
|
// Recursively check children
|
|
for _, child := range plan.Children() {
|
|
if hasLimit(child) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|