// Copyright 2016 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package core

import (
	"fmt"
	"math"
	"math/bits"
	"reflect"
	"strings"
	"unicode"

	"github.com/cznic/mathutil"
	"github.com/pingcap/errors"
	"github.com/pingcap/parser"
	"github.com/pingcap/parser/ast"
	"github.com/pingcap/parser/model"
	"github.com/pingcap/parser/mysql"
	"github.com/pingcap/parser/opcode"
	"github.com/pingcap/tidb/domain"
	"github.com/pingcap/tidb/expression"
	"github.com/pingcap/tidb/expression/aggregation"
	"github.com/pingcap/tidb/infoschema"
	"github.com/pingcap/tidb/metrics"
	"github.com/pingcap/tidb/planner/property"
	"github.com/pingcap/tidb/privilege"
	"github.com/pingcap/tidb/sessionctx"
	"github.com/pingcap/tidb/statistics"
	"github.com/pingcap/tidb/table"
	"github.com/pingcap/tidb/table/tables"
	"github.com/pingcap/tidb/types"
	"github.com/pingcap/tidb/types/parser_driver"
	"github.com/pingcap/tidb/util/chunk"
)

const (
	// TiDBMergeJoin is the hint name that enforces merge join.
	TiDBMergeJoin = "tidb_smj"
	// TiDBIndexNestedLoopJoin is the hint name that enforces index nested loop join.
	TiDBIndexNestedLoopJoin = "tidb_inlj"
	// TiDBHashJoin is the hint name that enforces hash join.
	TiDBHashJoin = "tidb_hj"
)

const (
	// ErrExprInSelect is the "SELECT list" clause name used in the ErrFieldNotInGroupBy error.
	ErrExprInSelect = "SELECT list"
	// ErrExprInOrderBy is the "ORDER BY" clause name used in the ErrFieldNotInGroupBy error.
	ErrExprInOrderBy = "ORDER BY"
)

// collectGroupByColumns rebuilds la.groupByCols from la.GroupByItems, keeping
// only the items that are plain *expression.Column values. The existing slice
// is truncated and reused rather than reallocated.
func (la *LogicalAggregation) collectGroupByColumns() {
	la.groupByCols = la.groupByCols[:0]
	for _, item := range la.GroupByItems {
		if col, ok := item.(*expression.Column); ok {
			la.groupByCols = append(la.groupByCols, col)
		}
	}
}

// buildAggregation builds a LogicalAggregation on top of p.
//
// Every argument of every function in aggFuncList is rewritten against p
// (rewriting may replace p with a new plan). A function that is Equal to one
// already added is merged with it instead of being added again. After the
// requested functions, one first_row aggregate is appended for each column of
// the (possibly replaced) child's schema.
//
// It returns the aggregation plan and a map from the index of each entry in
// aggFuncList to the index of the corresponding function in the plan's
// AggFuncs. On a rewrite error it returns the traced error.
func (b *PlanBuilder) buildAggregation(p LogicalPlan, aggFuncList []*ast.AggregateFuncExpr, gbyItems []expression.Expression) (LogicalPlan, map[int]int, error) {
	b.optFlag = b.optFlag | flagBuildKeyInfo
	b.optFlag = b.optFlag | flagPushDownAgg
	// We may apply aggregation eliminate optimization.
	// So we add the flagMaxMinEliminate to try to convert max/min to topn and flagPushDownTopN to handle the newly added topn operator.
	b.optFlag = b.optFlag | flagMaxMinEliminate
	b.optFlag = b.optFlag | flagPushDownTopN
	// when we eliminate the max and min we may add `is not null` filter.
	b.optFlag = b.optFlag | flagPredicatePushDown
	b.optFlag = b.optFlag | flagEliminateAgg
	b.optFlag = b.optFlag | flagEliminateProjection

	plan4Agg := LogicalAggregation{AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(aggFuncList))}.Init(b.ctx)
	schema4Agg := expression.NewSchema(make([]*expression.Column, 0, len(aggFuncList)+p.Schema().Len())...)
	// aggIndexMap maps the old index to new index after applying common aggregation functions elimination.
	aggIndexMap := make(map[int]int)

	for i, aggFunc := range aggFuncList {
		newArgList := make([]expression.Expression, 0, len(aggFunc.Args))
		for _, arg := range aggFunc.Args {
			newArg, np, err := b.rewrite(arg, p, nil, true)
			if err != nil {
				return nil, nil, errors.Trace(err)
			}
			// The rewrite may have produced a new child plan; keep using it.
			p = np
			newArgList = append(newArgList, newArg)
		}
		newFunc := aggregation.NewAggFuncDesc(b.ctx, aggFunc.F, newArgList, aggFunc.Distinct)
		// Reuse an identical aggregate function if one was already added.
		combined := false
		for j, oldFunc := range plan4Agg.AggFuncs {
			if oldFunc.Equal(b.ctx, newFunc) {
				aggIndexMap[i] = j
				combined = true
				break
			}
		}
		if !combined {
			position := len(plan4Agg.AggFuncs)
			aggIndexMap[i] = position
			plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, newFunc)
			schema4Agg.Append(&expression.Column{
				ColName:      model.NewCIStr(fmt.Sprintf("%d_col_%d", plan4Agg.id, position)),
				UniqueID:     b.ctx.GetSessionVars().AllocPlanColumnID(),
				IsReferenced: true,
				RetType:      newFunc.RetTp,
			})
		}
	}
	// Expose every child column through a first_row aggregate; the output
	// column is a clone of the child column carrying first_row's return type.
	for _, col := range p.Schema().Columns {
		newFunc := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false)
		plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, newFunc)
		newCol, _ := col.Clone().(*expression.Column)
		newCol.RetType = newFunc.RetTp
		schema4Agg.Append(newCol)
	}
	plan4Agg.SetChildren(p)
	plan4Agg.GroupByItems = gbyItems
	plan4Agg.SetSchema(schema4Agg)
	plan4Agg.collectGroupByColumns()
	return plan4Agg, aggIndexMap, nil
}

// buildResultSetNode builds the logical plan for a FROM-clause node.
//
// Joins, selects and unions are delegated to buildJoin, buildSelect and
// buildUnion. For a table source, the plan of the underlying select, union or
// table name is built first; then every schema column gets OrigTblName set to
// its current TblName and, when the source has an alias, TblName replaced by
// the alias and DBName cleared. A table source whose schema contains two
// columns with the same name is rejected with ErrDupFieldName. Any other node
// type yields ErrUnsupportedType.
func (b *PlanBuilder) buildResultSetNode(node ast.ResultSetNode) (p LogicalPlan, err error) {
	switch x := node.(type) {
	case *ast.Join:
		return b.buildJoin(x)
	case *ast.TableSource:
		switch v := x.Source.(type) {
		case *ast.SelectStmt:
			p, err = b.buildSelect(v)
		case *ast.UnionStmt:
			p, err = b.buildUnion(v)
		case *ast.TableName:
			p, err = b.buildDataSource(v)
		default:
			err = ErrUnsupportedType.GenWithStackByArgs(v)
		}
		if err != nil {
			return nil, errors.Trace(err)
		}

		if v, ok := p.(*DataSource); ok {
			v.TableAsName = &x.AsName
		}
		for _, col := range p.Schema().Columns {
			col.OrigTblName = col.TblName
			if x.AsName.L != "" {
				col.TblName = x.AsName
				col.DBName = model.NewCIStr("")
			}
		}
		// Duplicate column name in one table is not allowed.
		// "select * from (select 1, 1) as a;" is duplicate
		dupNames := make(map[string]struct{}, len(p.Schema().Columns))
		for _, col := range p.Schema().Columns {
			name := col.ColName.O
			if _, ok := dupNames[name]; ok {
				return nil, ErrDupFieldName.GenWithStackByArgs(name)
			}
			dupNames[name] = struct{}{}
		}
		return p, nil
	case *ast.SelectStmt:
		return b.buildSelect(x)
	case *ast.UnionStmt:
		return b.buildUnion(x)
	default:
		return nil, ErrUnsupportedType.GenWithStack("Unsupported ast.ResultSetNode(%T) for buildResultSetNode()", x)
	}
}

// extractOnCondition divides the conditions in CNF of a join node into 4 groups:
// equal conditions between one left column and one right column (eqCond),
// conditions that reference only the left child (leftCond), conditions that
// reference only the right child (rightCond), and everything else (otherCond).
// These conditions can be where conditions, join conditions, or a collection of both.
// If deriveLeft/deriveRight is set, we would try to derive more conditions for left/right plan.
func (p *LogicalJoin) extractOnCondition(conditions []expression.Expression, deriveLeft bool, deriveRight bool) (eqCond []*expression.ScalarFunction, leftCond []expression.Expression, rightCond []expression.Expression, otherCond []expression.Expression) {
	left, right := p.children[0], p.children[1]
	for _, expr := range conditions {
		binop, ok := expr.(*expression.ScalarFunction)
		if ok && len(binop.GetArgs()) == 2 {
			ctx := binop.GetCtx()
			arg0, lOK := binop.GetArgs()[0].(*expression.Column)
			arg1, rOK := binop.GetArgs()[1].(*expression.Column)
			if lOK && rOK {
				// Work out which argument belongs to which child; the two
				// columns may appear in either order in the expression.
				var leftCol, rightCol *expression.Column
				if left.Schema().Contains(arg0) && right.Schema().Contains(arg1) {
					leftCol, rightCol = arg0, arg1
				}
				if leftCol == nil && left.Schema().Contains(arg1) && right.Schema().Contains(arg0) {
					leftCol, rightCol = arg1, arg0
				}
				if leftCol != nil {
					// Do not derive `is not null` for anti join, since it may cause wrong results.
					// For example:
					// `select * from t t1 where t1.a not in (select b from t t2)` does not imply `t2.b is not null`,
					// `select * from t t1 where t1.a not in (select a from t t2 where t1.b = t2.b)` does not imply `t1.b is not null`,
					// `select * from t t1 where not exists (select * from t t2 where t2.a = t1.a)` does not imply `t1.a is not null`,
					if deriveLeft && p.JoinType != AntiSemiJoin {
						if isNullRejected(ctx, left.Schema(), expr) && !mysql.HasNotNullFlag(leftCol.RetType.Flag) {
							notNullExpr := expression.BuildNotNullExpr(ctx, leftCol)
							leftCond = append(leftCond, notNullExpr)
						}
					}
					if deriveRight && p.JoinType != AntiSemiJoin {
						if isNullRejected(ctx, right.Schema(), expr) && !mysql.HasNotNullFlag(rightCol.RetType.Flag) {
							notNullExpr := expression.BuildNotNullExpr(ctx, rightCol)
							rightCond = append(rightCond, notNullExpr)
						}
					}
				}
				// For queries like `select a in (select a from s where s.b = t.b) from t`,
				// if subquery is empty caused by `s.b = t.b`, the result should always be
				// false even if t.a is null or s.a is null. To make this join "empty aware",
				// we should differentiate `t.a = s.a` from other column equal conditions, so
				// we put it into OtherConditions instead of EqualConditions of join.
				if leftCol != nil && binop.FuncName.L == ast.EQ && !leftCol.InOperand && !rightCol.InOperand {
					// Rebuild the equality so that the left child's column is always the first argument.
					cond := expression.NewFunctionInternal(ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), leftCol, rightCol)
					eqCond = append(eqCond, cond.(*expression.ScalarFunction))
					continue
				}
			}
		}
		// Classify the remaining conditions by which child provides the referenced columns.
		columns := expression.ExtractColumns(expr)
		allFromLeft, allFromRight := true, true
		for _, col := range columns {
			if !left.Schema().Contains(col) {
				allFromLeft = false
			}
			if !right.Schema().Contains(col) {
				allFromRight = false
			}
		}
		// Note: an expression that references no columns satisfies both flags
		// and is classified as a right condition because that branch is tested first.
		if allFromRight {
			rightCond = append(rightCond, expr)
		} else if allFromLeft {
			leftCond = append(leftCond, expr)
		} else {
			// Relax expr to two supersets: leftRelaxedCond and rightRelaxedCond, the expression now is
			// `expr AND leftRelaxedCond AND rightRelaxedCond`. Motivation is to push filters down to
			// children as much as possible.
			if deriveLeft {
				leftRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(expr, left.Schema())
				if leftRelaxedCond != nil {
					leftCond = append(leftCond, leftRelaxedCond)
				}
			}
			if deriveRight {
				rightRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(expr, right.Schema())
				if rightRelaxedCond != nil {
					rightCond = append(rightCond, rightRelaxedCond)
				}
			}
			otherCond = append(otherCond, expr)
		}
	}
	return
}

// extractTableAlias returns a pointer to the TblName of the first column in
// p's schema, or nil when the schema is empty or that table name is empty.
func extractTableAlias(p LogicalPlan) *model.CIStr {
	if p.Schema().Len() > 0 && p.Schema().Columns[0].TblName.L != "" {
		return &(p.Schema().Columns[0].TblName)
	}
	return nil
}

// setPreferredJoinType records in p.preferJoinType the join algorithms that
// hintInfo requests for the aliases of p's two children, and stores hintInfo
// on p when at least one preference was set. A nil hintInfo is a no-op.
// It returns an error when the collected preferences conflict.
func (p *LogicalJoin) setPreferredJoinType(hintInfo *tableHintInfo) error {
	if hintInfo == nil {
		return nil
	}

	lhsAlias := extractTableAlias(p.children[0])
	rhsAlias := extractTableAlias(p.children[1])
	if hintInfo.ifPreferMergeJoin(lhsAlias, rhsAlias) {
		p.preferJoinType |= preferMergeJoin
	}
	if hintInfo.ifPreferHashJoin(lhsAlias, rhsAlias) {
		p.preferJoinType |= preferHashJoin
	}
	if hintInfo.ifPreferINLJ(lhsAlias) {
		p.preferJoinType |= preferLeftAsIndexInner
	}
	if hintInfo.ifPreferINLJ(rhsAlias) {
		p.preferJoinType |= preferRightAsIndexInner
	}

	// set hintInfo for further usage if this hint info can be used.
	if p.preferJoinType != 0 {
		p.hintInfo = hintInfo
	}

	// If there're multiple join types and one of them is not index join hint,
	// then there is a conflict of join types.
	if bits.OnesCount(p.preferJoinType) > 1 && (p.preferJoinType^preferRightAsIndexInner^preferLeftAsIndexInner) > 0 {
		return errors.New("Join hints are conflict, you can only specify one type of join")
	}
	return nil
}

// resetNotNullFlag clears mysql.NotNullFlag on the columns of schema in the
// index range [start, end). Each affected column is replaced by a copy that
// points to a copied field type, so the column and type objects previously
// stored in the schema are not modified.
func resetNotNullFlag(schema *expression.Schema, start, end int) {
	for i := start; i < end; i++ {
		col := *schema.Columns[i]
		newFieldType := *col.RetType
		newFieldType.Flag &= ^mysql.NotNullFlag
		col.RetType = &newFieldType
		schema.Columns[i] = &col
	}
}

// buildJoin builds the logical plan for a join node.
//
// When joinNode.Right is nil only the left result set is built. Otherwise both
// sides are built, a LogicalJoin is created over them with the merged schema,
// the join type and join hints are applied, and the NATURAL / USING / ON
// clause is processed. For outer joins the not-null flag is cleared on the
// columns that come from the inner side. An ON condition whose rewrite
// produces a new plan is rejected with an error.
func (b *PlanBuilder) buildJoin(joinNode *ast.Join) (LogicalPlan, error) {
	// We will construct a "Join" node for some statements like "INSERT",
	// "DELETE", "UPDATE", "REPLACE". For this scenario "joinNode.Right" is nil
	// and we only build the left "ResultSetNode".
	if joinNode.Right == nil {
		return b.buildResultSetNode(joinNode.Left)
	}

	b.optFlag = b.optFlag | flagPredicatePushDown

	leftPlan, err := b.buildResultSetNode(joinNode.Left)
	if err != nil {
		return nil, errors.Trace(err)
	}

	rightPlan, err := b.buildResultSetNode(joinNode.Right)
	if err != nil {
		return nil, errors.Trace(err)
	}

	joinPlan := LogicalJoin{StraightJoin: joinNode.StraightJoin || b.inStraightJoin}.Init(b.ctx)
	joinPlan.SetChildren(leftPlan, rightPlan)
	joinPlan.SetSchema(expression.MergeSchema(leftPlan.Schema(), rightPlan.Schema()))

	// Set join type.
	switch joinNode.Tp {
	case ast.LeftJoin:
		// left outer join need to be checked elimination
		b.optFlag = b.optFlag | flagEliminateOuterJoin
		joinPlan.JoinType = LeftOuterJoin
		// Columns from the right child occupy the tail of the merged schema.
		resetNotNullFlag(joinPlan.schema, leftPlan.Schema().Len(), joinPlan.schema.Len())
	case ast.RightJoin:
		// right outer join need to be checked elimination
		b.optFlag = b.optFlag | flagEliminateOuterJoin
		joinPlan.JoinType = RightOuterJoin
		// Columns from the left child occupy the head of the merged schema.
		resetNotNullFlag(joinPlan.schema, 0, leftPlan.Schema().Len())
	default:
		b.optFlag = b.optFlag | flagJoinReOrderGreedy
		joinPlan.JoinType = InnerJoin
	}

	// Merge sub join's redundantSchema into this join plan. When handle query like
	// select t2.a from (t1 join t2 using (a)) join t3 using (a);
	// we can simply search in the top level join plan to find redundant column.
	var lRedundant, rRedundant *expression.Schema
	if left, ok := leftPlan.(*LogicalJoin); ok && left.redundantSchema != nil {
		lRedundant = left.redundantSchema
	}
	if right, ok := rightPlan.(*LogicalJoin); ok && right.redundantSchema != nil {
		rRedundant = right.redundantSchema
	}
	joinPlan.redundantSchema = expression.MergeSchema(lRedundant, rRedundant)

	// Set preferred join algorithm if some join hints is specified by user.
	err = joinPlan.setPreferredJoinType(b.TableHints())
	if err != nil {
		return nil, errors.Trace(err)
	}

	// "NATURAL JOIN" doesn't have "ON" or "USING" conditions.
	//
	// The "NATURAL [LEFT] JOIN" of two tables is defined to be semantically
	// equivalent to an "INNER JOIN" or a "LEFT JOIN" with a "USING" clause
	// that names all columns that exist in both tables.
	//
	// See https://dev.mysql.com/doc/refman/5.7/en/join.html for more detail.
	if joinNode.NaturalJoin {
		err = b.buildNaturalJoin(joinPlan, leftPlan, rightPlan, joinNode)
		if err != nil {
			return nil, errors.Trace(err)
		}
	} else if joinNode.Using != nil {
		err = b.buildUsingClause(joinPlan, leftPlan, rightPlan, joinNode)
		if err != nil {
			return nil, errors.Trace(err)
		}
	} else if joinNode.On != nil {
		b.curClause = onClause
		onExpr, newPlan, err := b.rewrite(joinNode.On.Expr, joinPlan, nil, false)
		if err != nil {
			return nil, errors.Trace(err)
		}
		// A rewrite that returns a different plan means the ON expression
		// required building extra plan nodes, which is not supported here.
		if newPlan != joinPlan {
			return nil, errors.New("ON condition doesn't support subqueries yet")
		}
		onCondition := expression.SplitCNFItems(onExpr)
		joinPlan.attachOnConds(onCondition)
	} else if joinPlan.JoinType == InnerJoin {
		// If a inner join without "ON" or "USING" clause, it's a cartesian
		// product over the join tables.
		joinPlan.cartesianJoin = true
	}

	return joinPlan, nil
}

// buildUsingClause eliminate the redundant columns and ordering columns based
// on the "USING" clause.
//
// According to the standard SQL, columns are ordered in the following way:
// 1. coalesced common columns of "leftPlan" and "rightPlan", in the order they
//    appear in "leftPlan".
// 2. the rest columns in "leftPlan", in the order they appear in "leftPlan".
// 3. the rest columns in "rightPlan", in the order they appear in "rightPlan".
func (b *PlanBuilder) buildUsingClause(p *LogicalJoin, leftPlan, rightPlan LogicalPlan, join *ast.Join) error {
	// filter holds the lower-cased USING column names; true means "not matched yet".
	filter := make(map[string]bool, len(join.Using))
	for _, col := range join.Using {
		filter[col.Name.L] = true
	}
	return b.coalesceCommonColumns(p, leftPlan, rightPlan, join.Tp == ast.RightJoin, filter)
}

// buildNaturalJoin builds natural join output schema. It finds out all the common columns
// then uses the same mechanism as buildUsingClause to eliminate redundant columns and build join conditions.
// According to standard SQL, producing this display order:
// 	All the common columns
// 	Every column in the first (left) table that is not a common column
// 	Every column in the second (right) table that is not a common column
func (b *PlanBuilder) buildNaturalJoin(p *LogicalJoin, leftPlan, rightPlan LogicalPlan, join *ast.Join) error {
	return b.coalesceCommonColumns(p, leftPlan, rightPlan, join.Tp == ast.RightJoin, nil)
}

// coalesceCommonColumns is used by buildUsingClause and buildNaturalJoin. The filter is used by buildUsingClause.
//
// It works on clones of both children's schemas. Columns whose lower-cased
// names match on both sides (and, when filter is non-empty, are listed in
// filter) are moved to the front of each column list. The join schema is then
// set to: all columns of the first list followed by the non-common columns of
// the second list, where "first" is the right child when rightJoin is true.
// The common columns of the second list are merged into p.redundantSchema,
// and one equality condition per common column pair is prepended to
// p.OtherConditions.
//
// filter is mutated: entries that were matched are set to false. If filter is
// non-empty and some entry was never matched, ErrUnknownColumn is returned.
func (b *PlanBuilder) coalesceCommonColumns(p *LogicalJoin, leftPlan, rightPlan LogicalPlan, rightJoin bool, filter map[string]bool) error {
	lsc := leftPlan.Schema().Clone()
	rsc := rightPlan.Schema().Clone()
	// lColumns/rColumns alias the Columns slices of the cloned schemas, so the
	// in-place reordering below is also visible through lsc.Columns/rsc.Columns.
	lColumns, rColumns := lsc.Columns, rsc.Columns
	if rightJoin {
		lColumns, rColumns = rsc.Columns, lsc.Columns
	}

	// Find out all the common columns and put them ahead.
	commonLen := 0
	for i, lCol := range lColumns {
		// Entries before commonLen in rColumns are already matched; skip them.
		for j := commonLen; j < len(rColumns); j++ {
			if lCol.ColName.L != rColumns[j].ColName.L {
				continue
			}

			if len(filter) > 0 {
				// Not listed in USING (or already matched): leave this column in place.
				if !filter[lCol.ColName.L] {
					break
				}
				// Mark this column exist.
				filter[lCol.ColName.L] = false
			}

			// Move lColumns[i] to index commonLen, shifting the elements in
			// [commonLen, i) one slot to the right to keep their relative order.
			col := lColumns[i]
			copy(lColumns[commonLen+1:i+1], lColumns[commonLen:i])
			lColumns[commonLen] = col

			// Same rotation for the matching column on the other side.
			col = rColumns[j]
			copy(rColumns[commonLen+1:j+1], rColumns[commonLen:j])
			rColumns[commonLen] = col

			commonLen++
			break
		}
	}

	// Some USING column was not found as a common column: report one that is
	// still marked as unmatched.
	if len(filter) > 0 && len(filter) != commonLen {
		for col, notExist := range filter {
			if notExist {
				return ErrUnknownColumn.GenWithStackByArgs(col, "from clause")
			}
		}
	}

	schemaCols := make([]*expression.Column, len(lColumns)+len(rColumns)-commonLen)
	copy(schemaCols[:len(lColumns)], lColumns)
	copy(schemaCols[len(lColumns):], rColumns[commonLen:])

	conds := make([]expression.Expression, 0, commonLen)
	for i := 0; i < commonLen; i++ {
		// After the reordering, index i of both cloned schemas holds the i-th
		// common column; the condition is always built as left = right.
		lc, rc := lsc.Columns[i], rsc.Columns[i]
		cond, err := expression.NewFunction(b.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lc, rc)
		if err != nil {
			return errors.Trace(err)
		}
		conds = append(conds, cond)
	}

	p.SetSchema(expression.NewSchema(schemaCols...))
	p.redundantSchema = expression.MergeSchema(p.redundantSchema, expression.NewSchema(rColumns[:commonLen]...))
	// The new equality conditions go before any existing other conditions.
	p.OtherConditions = append(conds, p.OtherConditions...)

	return nil
}

// buildSelection builds a LogicalSelection over p for the given WHERE/HAVING
// expression.
//
// The expression is split by splitWhere, each part is rewritten against p
// (rewriting may replace p), and the result is split into CNF items. Constant
// items without a deferred expression are evaluated immediately: an
// always-true constant is dropped, and an always-false constant makes the
// function return a LogicalTableDual with p's schema. If no condition
// remains, p itself is returned.
func (b *PlanBuilder) buildSelection(p LogicalPlan, where ast.ExprNode, AggMapper map[*ast.AggregateFuncExpr]int) (LogicalPlan, error) {
	b.optFlag = b.optFlag | flagPredicatePushDown
	// Keep havingClause when the caller already set it; otherwise this is a WHERE clause.
	if b.curClause != havingClause {
		b.curClause = whereClause
	}

	conditions := splitWhere(where)
	expressions := make([]expression.Expression, 0, len(conditions))
	selection := LogicalSelection{}.Init(b.ctx)
	for _, cond := range conditions {
		expr, np, err := b.rewrite(cond, p, AggMapper, false)
		if err != nil {
			return nil, errors.Trace(err)
		}
		p = np
		if expr == nil {
			continue
		}
		cnfItems := expression.SplitCNFItems(expr)
		for _, item := range cnfItems {
			if con, ok := item.(*expression.Constant); ok && con.DeferredExpr == nil {
				ret, _, err := expression.EvalBool(b.ctx, expression.CNFExprs{con}, chunk.Row{})
				// NOTE(review): an evaluation error is not propagated here; the
				// constant is silently dropped in that case — confirm this is intended.
				if err != nil || ret {
					continue
				}
				// If there is condition which is always false, return dual plan directly.
				dual := LogicalTableDual{}.Init(b.ctx)
				dual.SetSchema(p.Schema())
				return dual, nil
			}
			expressions = append(expressions, item)
		}
	}
	if len(expressions) == 0 {
		return p, nil
	}
	selection.Conditions = expressions
	selection.SetChildren(p)
	return selection, nil
}

// buildProjectionFieldNameFromColumns builds the field name, table name and database name when field expression is a column reference.
//
// The original column name and table qualifier are taken from the AST column
// reference (after stripping parentheses and unary plus). The output column
// name is the field alias when present, otherwise the original column name.
// An empty table qualifier falls back to c.TblName. The original table name
// and the database name are always taken from c.
func (b *PlanBuilder) buildProjectionFieldNameFromColumns(field *ast.SelectField, c *expression.Column) (colName, origColName, tblName, origTblName, dbName model.CIStr) {
	if astCol, ok := getInnerFromParenthesesAndUnaryPlus(field.Expr).(*ast.ColumnNameExpr); ok {
		origColName, tblName, dbName = astCol.Name.Name, astCol.Name.Table, astCol.Name.Schema
	}
	if field.AsName.L != "" {
		colName = field.AsName
	} else {
		colName = origColName
	}
	if tblName.L == "" {
		tblName = c.TblName
	}
	if dbName.L == "" {
		dbName = c.DBName
	}
	// NOTE(review): the dbName computed above is not what is returned; the
	// last result is always c.DBName, so a schema qualifier written in the
	// query is ignored — confirm this is intended.
	return colName, origColName, tblName, c.OrigTblName, c.DBName
}

// buildProjectionFieldNameFromExpressions builds the field name when field expression is a normal expression.
func (b *PlanBuilder) buildProjectionFieldNameFromExpressions(field *ast.SelectField) (model.CIStr, error) { if agg, ok := field.Expr.(*ast.AggregateFuncExpr); ok && agg.F == ast.AggFuncFirstRow { // When the query is select t.a from t group by a; The Column Name should be a but not t.a; return agg.Args[0].(*ast.ColumnNameExpr).Name.Name, nil } innerExpr := getInnerFromParenthesesAndUnaryPlus(field.Expr) funcCall, isFuncCall := innerExpr.(*ast.FuncCallExpr) // When used to produce a result set column, NAME_CONST() causes the column to have the given name. // See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_name-const for details if isFuncCall && funcCall.FnName.L == ast.NameConst { if v, err := evalAstExpr(b.ctx, funcCall.Args[0]); err == nil { if s, err := v.ToString(); err == nil { return model.NewCIStr(s), nil } } return model.NewCIStr(""), ErrWrongArguments.GenWithStackByArgs("NAME_CONST") } valueExpr, isValueExpr := innerExpr.(*driver.ValueExpr) // Non-literal: Output as inputed, except that comments need to be removed. if !isValueExpr { return model.NewCIStr(parser.SpecFieldPattern.ReplaceAllStringFunc(field.Text(), parser.TrimComment)), nil } // Literal: Need special processing switch valueExpr.Kind() { case types.KindString: projName := valueExpr.GetString() projOffset := valueExpr.GetProjectionOffset() if projOffset >= 0 { projName = projName[:projOffset] } // See #3686, #3994: // For string literals, string content is used as column name. Non-graph initial characters are trimmed. fieldName := strings.TrimLeftFunc(projName, func(r rune) bool { return !unicode.IsOneOf(mysql.RangeGraph, r) }) return model.NewCIStr(fieldName), nil case types.KindNull: // See #4053, #3685 return model.NewCIStr("NULL"), nil default: // Keep as it is. 
if innerExpr.Text() != "" { return model.NewCIStr(innerExpr.Text()), nil } return model.NewCIStr(field.Text()), nil } } // buildProjectionField builds the field object according to SelectField in projection. func (b *PlanBuilder) buildProjectionField(id, position int, field *ast.SelectField, expr expression.Expression) (*expression.Column, error) { var origTblName, tblName, origColName, colName, dbName model.CIStr if c, ok := expr.(*expression.Column); ok && !c.IsReferenced { // Field is a column reference. colName, origColName, tblName, origTblName, dbName = b.buildProjectionFieldNameFromColumns(field, c) } else if field.AsName.L != "" { // Field has alias. colName = field.AsName } else { // Other: field is an expression. var err error if colName, err = b.buildProjectionFieldNameFromExpressions(field); err != nil { return nil, errors.Trace(err) } } return &expression.Column{ UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), TblName: tblName, OrigTblName: origTblName, ColName: colName, OrigColName: origColName, DBName: dbName, RetType: expr.GetType(), }, nil } // buildProjection returns a Projection plan and non-aux columns length. func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, mapper map[*ast.AggregateFuncExpr]int, considerWindow bool) (LogicalPlan, int, error) { b.optFlag |= flagEliminateProjection b.curClause = fieldList proj := LogicalProjection{Exprs: make([]expression.Expression, 0, len(fields))}.Init(b.ctx) schema := expression.NewSchema(make([]*expression.Column, 0, len(fields))...) oldLen := 0 for i, field := range fields { if !field.Auxiliary { oldLen++ } isWindowFuncField := ast.HasWindowFlag(field.Expr) // Although window functions occurs in the select fields, but it has to be processed after having clause. // So when we build the projection for select fields, we need to skip the window function. // When `considerWindow` is false, we will only build fields for non-window functions, so we add fake placeholders. 
// for window functions. These fake placeholders will be erased in column pruning. // When `considerWindow` is true, all the non-window fields have been built, so we just use the schema columns. if considerWindow && !isWindowFuncField { col := p.Schema().Columns[i] proj.Exprs = append(proj.Exprs, col) schema.Append(col) continue } else if !considerWindow && isWindowFuncField { expr := expression.Zero proj.Exprs = append(proj.Exprs, expr) col, err := b.buildProjectionField(proj.id, schema.Len()+1, field, expr) if err != nil { return nil, 0, errors.Trace(err) } schema.Append(col) continue } newExpr, np, err := b.rewrite(field.Expr, p, mapper, true) if err != nil { return nil, 0, errors.Trace(err) } p = np proj.Exprs = append(proj.Exprs, newExpr) col, err := b.buildProjectionField(proj.id, schema.Len()+1, field, newExpr) if err != nil { return nil, 0, errors.Trace(err) } schema.Append(col) } proj.SetSchema(schema) proj.SetChildren(p) return proj, oldLen, nil } func (b *PlanBuilder) buildDistinct(child LogicalPlan, length int) *LogicalAggregation { b.optFlag = b.optFlag | flagBuildKeyInfo b.optFlag = b.optFlag | flagPushDownAgg plan4Agg := LogicalAggregation{ AggFuncs: make([]*aggregation.AggFuncDesc, 0, child.Schema().Len()), GroupByItems: expression.Column2Exprs(child.Schema().Clone().Columns[:length]), }.Init(b.ctx) plan4Agg.collectGroupByColumns() for _, col := range child.Schema().Columns { aggDesc := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, aggDesc) } plan4Agg.SetChildren(child) plan4Agg.SetSchema(child.Schema().Clone()) // Distinct will be rewritten as first_row, we reset the type here since the return type // of first_row is not always the same as the column arg of first_row. for i, col := range plan4Agg.schema.Columns { col.RetType = plan4Agg.AggFuncs[i].RetTp } return plan4Agg } // unionJoinFieldType finds the type which can carry the given types in Union. 
func unionJoinFieldType(a, b *types.FieldType) *types.FieldType { resultTp := types.NewFieldType(types.MergeFieldType(a.Tp, b.Tp)) // This logic will be intelligible when it is associated with the buildProjection4Union logic. if resultTp.Tp == mysql.TypeNewDecimal { // The decimal result type will be unsigned only when all the decimals to be united are unsigned. resultTp.Flag &= b.Flag & mysql.UnsignedFlag } else { // Non-decimal results will be unsigned when the first SQL statement result in the union is unsigned. resultTp.Flag |= a.Flag & mysql.UnsignedFlag } resultTp.Decimal = mathutil.Max(a.Decimal, b.Decimal) // `Flen - Decimal` is the fraction before '.' resultTp.Flen = mathutil.Max(a.Flen-a.Decimal, b.Flen-b.Decimal) + resultTp.Decimal if resultTp.EvalType() != types.ETInt && (a.EvalType() == types.ETInt || b.EvalType() == types.ETInt) && resultTp.Flen < mysql.MaxIntWidth { resultTp.Flen = mysql.MaxIntWidth } resultTp.Charset = a.Charset resultTp.Collate = a.Collate expression.SetBinFlagOrBinStr(b, resultTp) return resultTp } func (b *PlanBuilder) buildProjection4Union(u *LogicalUnionAll) { unionCols := make([]*expression.Column, 0, u.children[0].Schema().Len()) // Infer union result types by its children's schema. for i, col := range u.children[0].Schema().Columns { resultTp := col.RetType for j := 1; j < len(u.children); j++ { childTp := u.children[j].Schema().Columns[i].RetType resultTp = unionJoinFieldType(resultTp, childTp) } unionCols = append(unionCols, &expression.Column{ ColName: col.ColName, RetType: resultTp, UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), }) } u.schema = expression.NewSchema(unionCols...) // Process each child and add a projection above original child. // So the schema of `UnionAll` can be the same with its children's. 
for childID, child := range u.children { exprs := make([]expression.Expression, len(child.Schema().Columns)) for i, srcCol := range child.Schema().Columns { dstType := unionCols[i].RetType srcType := srcCol.RetType if !srcType.Equal(dstType) { exprs[i] = expression.BuildCastFunction4Union(b.ctx, srcCol, dstType) } else { exprs[i] = srcCol } } b.optFlag |= flagEliminateProjection proj := LogicalProjection{Exprs: exprs, avoidColumnEvaluator: true}.Init(b.ctx) proj.SetSchema(u.schema.Clone()) proj.SetChildren(child) u.children[childID] = proj } } func (b *PlanBuilder) buildUnion(union *ast.UnionStmt) (LogicalPlan, error) { distinctSelectPlans, allSelectPlans, err := b.divideUnionSelectPlans(union.SelectList.Selects) if err != nil { return nil, errors.Trace(err) } unionDistinctPlan := b.buildUnionAll(distinctSelectPlans) if unionDistinctPlan != nil { unionDistinctPlan = b.buildDistinct(unionDistinctPlan, unionDistinctPlan.Schema().Len()) if len(allSelectPlans) > 0 { // Can't change the statements order in order to get the correct column info. allSelectPlans = append([]LogicalPlan{unionDistinctPlan}, allSelectPlans...) } } unionAllPlan := b.buildUnionAll(allSelectPlans) unionPlan := unionDistinctPlan if unionAllPlan != nil { unionPlan = unionAllPlan } oldLen := unionPlan.Schema().Len() if union.OrderBy != nil { unionPlan, err = b.buildSort(unionPlan, union.OrderBy.Items, nil) if err != nil { return nil, errors.Trace(err) } } if union.Limit != nil { unionPlan, err = b.buildLimit(unionPlan, union.Limit) if err != nil { return nil, errors.Trace(err) } } // Fix issue #8189 (https://github.com/pingcap/tidb/issues/8189). // If there are extra expressions generated from `ORDER BY` clause, generate a `Projection` to remove them. if oldLen != unionPlan.Schema().Len() { proj := LogicalProjection{Exprs: expression.Column2Exprs(unionPlan.Schema().Columns[:oldLen])}.Init(b.ctx) proj.SetChildren(unionPlan) schema := expression.NewSchema(unionPlan.Schema().Clone().Columns[:oldLen]...) 
for _, col := range schema.Columns { col.UniqueID = b.ctx.GetSessionVars().AllocPlanColumnID() } proj.SetSchema(schema) return proj, nil } return unionPlan, nil } // divideUnionSelectPlans resolves union's select stmts to logical plans. // and divide result plans into "union-distinct" and "union-all" parts. // divide rule ref: https://dev.mysql.com/doc/refman/5.7/en/union.html // "Mixed UNION types are treated such that a DISTINCT union overrides any ALL union to its left." func (b *PlanBuilder) divideUnionSelectPlans(selects []*ast.SelectStmt) (distinctSelects []LogicalPlan, allSelects []LogicalPlan, err error) { firstUnionAllIdx, columnNums := 0, -1 // The last slot is reserved for appending distinct union outside this function. children := make([]LogicalPlan, len(selects), len(selects)+1) for i := len(selects) - 1; i >= 0; i-- { stmt := selects[i] if firstUnionAllIdx == 0 && stmt.IsAfterUnionDistinct { firstUnionAllIdx = i + 1 } selectPlan, err := b.buildSelect(stmt) if err != nil { return nil, nil, errors.Trace(err) } if columnNums == -1 { columnNums = selectPlan.Schema().Len() } if selectPlan.Schema().Len() != columnNums { return nil, nil, ErrWrongNumberOfColumnsInSelect.GenWithStackByArgs() } children[i] = selectPlan } return children[:firstUnionAllIdx], children[firstUnionAllIdx:], nil } func (b *PlanBuilder) buildUnionAll(subPlan []LogicalPlan) LogicalPlan { if len(subPlan) == 0 { return nil } u := LogicalUnionAll{}.Init(b.ctx) u.children = subPlan b.buildProjection4Union(u) return u } // ByItems wraps a "by" item. type ByItems struct { Expr expression.Expression Desc bool } // String implements fmt.Stringer interface. func (by *ByItems) String() string { if by.Desc { return fmt.Sprintf("%s true", by.Expr) } return by.Expr.String() } // Clone makes a copy of ByItems. 
func (by *ByItems) Clone() *ByItems {
	cloned := &ByItems{
		Expr: by.Expr.Clone(),
		Desc: by.Desc,
	}
	return cloned
}

// itemTransformer transforms ParamMarkerExpr to PositionExpr in the context of ByItem
type itemTransformer struct{}

// Enter replaces a parameter marker with the position expression built from
// it and asks the visitor to skip the node's children; every other node is
// returned unchanged.
func (t *itemTransformer) Enter(inNode ast.Node) (ast.Node, bool) {
	if marker, isMarker := inNode.(*driver.ParamMarkerExpr); isMarker {
		return expression.ConstructPositionExpr(marker), true
	}
	return inNode, false
}

// Leave returns the node unchanged.
func (t *itemTransformer) Leave(inNode ast.Node) (ast.Node, bool) {
	return inNode, false
}

// buildSort builds a LogicalSort over p for the given ORDER BY items.
// Parameter markers in the items are first turned into position expressions
// (the AST items are updated in place); each item is then rewritten against
// p, which may replace p with a new plan.
func (b *PlanBuilder) buildSort(p LogicalPlan, byItems []*ast.ByItem, aggMapper map[*ast.AggregateFuncExpr]int) (*LogicalSort, error) {
	// An ORDER BY directly on top of a union is a global order-by.
	b.curClause = orderByClause
	if _, onUnion := p.(*LogicalUnionAll); onUnion {
		b.curClause = globalOrderByClause
	}

	sort := LogicalSort{}.Init(b.ctx)
	sortItems := make([]*ByItems, 0, len(byItems))
	visitor := &itemTransformer{}
	for _, astItem := range byItems {
		transformed, _ := astItem.Expr.Accept(visitor)
		astItem.Expr = transformed.(ast.ExprNode)

		rewritten, newPlan, err := b.rewrite(astItem.Expr, p, aggMapper, true)
		if err != nil {
			return nil, errors.Trace(err)
		}
		p = newPlan

		sortItems = append(sortItems, &ByItems{Expr: rewritten, Desc: astItem.Desc})
	}

	sort.ByItems = sortItems
	sort.SetChildren(p)
	return sort, nil
}

// getUintFromNode gets uint64 value from ast.Node.
// For ordinary statement, node should be uint64 constant value.
// For prepared statement, node is string. We should convert it to uint64.
func getUintFromNode(ctx sessionctx.Context, n ast.Node) (uVal uint64, isNull bool, isExpectedType bool) { var val interface{} switch v := n.(type) { case *driver.ValueExpr: val = v.GetValue() case *driver.ParamMarkerExpr: param, err := expression.GetParamExpression(ctx, v) if err != nil { return 0, false, false } str, isNull, err := expression.GetStringFromConstant(ctx, param) if err != nil { return 0, false, false } if isNull { return 0, true, true } val = str default: return 0, false, false } switch v := val.(type) { case uint64: return v, false, true case int64: if v >= 0 { return uint64(v), false, true } case string: sc := ctx.GetSessionVars().StmtCtx uVal, err := types.StrToUint(sc, v) if err != nil { return 0, false, false } return uVal, false, true } return 0, false, false } func extractLimitCountOffset(ctx sessionctx.Context, limit *ast.Limit) (count uint64, offset uint64, err error) { var isExpectedType bool if limit.Count != nil { count, _, isExpectedType = getUintFromNode(ctx, limit.Count) if !isExpectedType { return 0, 0, ErrWrongArguments.GenWithStackByArgs("LIMIT") } } if limit.Offset != nil { offset, _, isExpectedType = getUintFromNode(ctx, limit.Offset) if !isExpectedType { return 0, 0, ErrWrongArguments.GenWithStackByArgs("LIMIT") } } return count, offset, nil } func (b *PlanBuilder) buildLimit(src LogicalPlan, limit *ast.Limit) (LogicalPlan, error) { b.optFlag = b.optFlag | flagPushDownTopN var ( offset, count uint64 err error ) if count, offset, err = extractLimitCountOffset(b.ctx, limit); err != nil { return nil, err } if count > math.MaxUint64-offset { count = math.MaxUint64 - offset } if offset+count == 0 { tableDual := LogicalTableDual{RowCount: 0}.Init(b.ctx) tableDual.schema = src.Schema() return tableDual, nil } li := LogicalLimit{ Offset: offset, Count: count, }.Init(b.ctx) li.SetChildren(src) return li, nil } // colMatch means that if a match b, e.g. t.a can match test.t.a but test.t.a can't match t.a. 
// Because column a want column from database test exactly. func colMatch(a *ast.ColumnName, b *ast.ColumnName) bool { if a.Schema.L == "" || a.Schema.L == b.Schema.L { if a.Table.L == "" || a.Table.L == b.Table.L { return a.Name.L == b.Name.L } } return false } func matchField(f *ast.SelectField, col *ast.ColumnNameExpr, ignoreAsName bool) bool { // if col specify a table name, resolve from table source directly. if col.Name.Table.L == "" { if f.AsName.L == "" || ignoreAsName { if curCol, isCol := f.Expr.(*ast.ColumnNameExpr); isCol { return curCol.Name.Name.L == col.Name.Name.L } else if _, isFunc := f.Expr.(*ast.FuncCallExpr); isFunc { // Fix issue 7331 // If there are some function calls in SelectField, we check if // ColumnNameExpr in GroupByClause matches one of these function calls. // Example: select concat(k1,k2) from t group by `concat(k1,k2)`, // `concat(k1,k2)` matches with function call concat(k1, k2). return strings.ToLower(f.Text()) == col.Name.Name.L } // a expression without as name can't be matched. return false } return f.AsName.L == col.Name.Name.L } return false } func resolveFromSelectFields(v *ast.ColumnNameExpr, fields []*ast.SelectField, ignoreAsName bool) (index int, err error) { var matchedExpr ast.ExprNode index = -1 for i, field := range fields { if field.Auxiliary { continue } if matchField(field, v, ignoreAsName) { curCol, isCol := field.Expr.(*ast.ColumnNameExpr) if !isCol { return i, nil } if matchedExpr == nil { matchedExpr = curCol index = i } else if !colMatch(matchedExpr.(*ast.ColumnNameExpr).Name, curCol.Name) && !colMatch(curCol.Name, matchedExpr.(*ast.ColumnNameExpr).Name) { return -1, ErrAmbiguous.GenWithStackByArgs(curCol.Name.Name.L, clauseMsg[fieldList]) } } } return } // havingWindowAndOrderbyExprResolver visits Expr tree. // It converts ColunmNameExpr to AggregateFuncExpr and collects AggregateFuncExpr. 
type havingWindowAndOrderbyExprResolver struct { inAggFunc bool inWindowFunc bool inExpr bool orderBy bool err error p LogicalPlan selectFields []*ast.SelectField aggMapper map[*ast.AggregateFuncExpr]int colMapper map[*ast.ColumnNameExpr]int gbyItems []*ast.ByItem outerSchemas []*expression.Schema curClause clauseCode } // Enter implements Visitor interface. func (a *havingWindowAndOrderbyExprResolver) Enter(n ast.Node) (node ast.Node, skipChildren bool) { switch n.(type) { case *ast.AggregateFuncExpr: a.inAggFunc = true case *ast.WindowFuncExpr: a.inWindowFunc = true case *driver.ParamMarkerExpr, *ast.ColumnNameExpr, *ast.ColumnName: case *ast.SubqueryExpr, *ast.ExistsSubqueryExpr: // Enter a new context, skip it. // For example: select sum(c) + c + exists(select c from t) from t; return n, true default: a.inExpr = true } return n, false } func (a *havingWindowAndOrderbyExprResolver) resolveFromSchema(v *ast.ColumnNameExpr, schema *expression.Schema) (int, error) { col, err := schema.FindColumn(v.Name) if err != nil { return -1, errors.Trace(err) } if col == nil { return -1, nil } newColName := &ast.ColumnName{ Schema: col.DBName, Table: col.TblName, Name: col.ColName, } for i, field := range a.selectFields { if c, ok := field.Expr.(*ast.ColumnNameExpr); ok && colMatch(c.Name, newColName) { return i, nil } } sf := &ast.SelectField{ Expr: &ast.ColumnNameExpr{Name: newColName}, Auxiliary: true, } sf.Expr.SetType(col.GetType()) a.selectFields = append(a.selectFields, sf) return len(a.selectFields) - 1, nil } // Leave implements Visitor interface. 
// It appends every aggregate function it leaves to the select fields as an
// auxiliary field, and resolves every column reference to an index in the
// select fields (recorded in colMapper). On failure the error is stored in
// a.err and ok is false.
func (a *havingWindowAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool) {
	switch v := n.(type) {
	case *ast.AggregateFuncExpr:
		a.inAggFunc = false
		// The aggregate becomes a new auxiliary select field; remember its index.
		a.aggMapper[v] = len(a.selectFields)
		a.selectFields = append(a.selectFields, &ast.SelectField{
			Auxiliary: true,
			Expr:      v,
			AsName:    model.NewCIStr(fmt.Sprintf("sel_agg_%d", len(a.selectFields))),
		})
	case *ast.WindowFuncExpr:
		a.inWindowFunc = false
		// A window function is rejected inside the HAVING clause.
		if a.curClause == havingClause {
			a.err = ErrWindowInvalidWindowFuncUse.GenWithStackByArgs(v.F)
			return node, false
		}
	case *ast.ColumnNameExpr:
		// Decide whether the select fields or the plan schema is consulted first.
		resolveFieldsFirst := true
		if a.inAggFunc || a.inWindowFunc || (a.orderBy && a.inExpr) {
			resolveFieldsFirst = false
		}
		if !a.inAggFunc && !a.orderBy {
			// A column that matches a GROUP BY column is resolved from the schema first.
			for _, item := range a.gbyItems {
				if col, ok := item.Expr.(*ast.ColumnNameExpr); ok &&
					(colMatch(v.Name, col.Name) || colMatch(col.Name, v.Name)) {
					resolveFieldsFirst = false
					break
				}
			}
		}
		var index int
		if resolveFieldsFirst {
			index, a.err = resolveFromSelectFields(v, a.selectFields, false)
			if a.err != nil {
				return node, false
			}
			// HAVING must not reference the alias of a select field that contains a window function.
			if index != -1 && a.curClause == havingClause && ast.HasWindowFlag(a.selectFields[index].Expr) {
				a.err = ErrWindowInvalidWindowFuncAliasUse.GenWithStackByArgs(v.Name.Name.O)
				return node, false
			}
			if index == -1 {
				// Fallback: ORDER BY looks into the plan schema, the other
				// clauses retry the select fields ignoring aliases.
				if a.orderBy {
					index, a.err = a.resolveFromSchema(v, a.p.Schema())
				} else {
					index, a.err = resolveFromSelectFields(v, a.selectFields, true)
				}
			}
		} else {
			// We should ignore the err when resolving from schema. Because we could resolve successfully
			// when considering select fields.
			var err error
			index, err = a.resolveFromSchema(v, a.p.Schema())
			_ = err
			if index == -1 && a.curClause != windowClause {
				index, a.err = resolveFromSelectFields(v, a.selectFields, false)
				if index != -1 && a.curClause == havingClause && ast.HasWindowFlag(a.selectFields[index].Expr) {
					a.err = ErrWindowInvalidWindowFuncAliasUse.GenWithStackByArgs(v.Name.Name.O)
					return node, false
				}
			}
		}
		if a.err != nil {
			return node, false
		}
		if index == -1 {
			// If we can't find it anywhere, it may be a correlated column.
			for _, schema := range a.outerSchemas {
				col, err1 := schema.FindColumn(v.Name)
				if err1 != nil {
					a.err = errors.Trace(err1)
					return node, false
				}
				if col != nil {
					return n, true
				}
			}
			a.err = ErrUnknownColumn.GenWithStackByArgs(v.Name.OrigColName(), clauseMsg[a.curClause])
			return node, false
		}
		// Inside an aggregate function the column reference is replaced by the
		// expression of the select field it resolved to.
		if a.inAggFunc {
			return a.selectFields[index].Expr, true
		}
		a.colMapper[v] = index
	}
	return n, true
}

// resolveHavingAndOrderBy will process aggregate functions and resolve the columns that don't exist in select fields.
// If we found some columns that are not in select fields, we will append it to select fields and update the colMapper.
// When we rewrite the order by / having expression, we will find column in map at first.
// It returns the aggregate mapper built from the HAVING clause, followed by the
// one built from the ORDER BY clause.
func (b *PlanBuilder) resolveHavingAndOrderBy(sel *ast.SelectStmt, p LogicalPlan) (
	map[*ast.AggregateFuncExpr]int, map[*ast.AggregateFuncExpr]int, error) {
	extractor := &havingWindowAndOrderbyExprResolver{
		p:            p,
		selectFields: sel.Fields.Fields,
		aggMapper:    make(map[*ast.AggregateFuncExpr]int),
		colMapper:    b.colMapper,
		outerSchemas: b.outerSchemas,
	}
	if sel.GroupBy != nil {
		extractor.gbyItems = sel.GroupBy.Items
	}
	// Extract agg funcs from having clause.
	if sel.Having != nil {
		extractor.curClause = havingClause
		n, ok := sel.Having.Expr.Accept(extractor)
		if !ok {
			return nil, nil, errors.Trace(extractor.err)
		}
		sel.Having.Expr = n.(ast.ExprNode)
	}
	havingAggMapper := extractor.aggMapper
	// Start from a fresh mapper and reset the visitor state for ORDER BY.
	extractor.aggMapper = make(map[*ast.AggregateFuncExpr]int)
	extractor.orderBy = true
	extractor.inExpr = false
	// Extract agg funcs from order by clause.
	if sel.OrderBy != nil {
		extractor.curClause = orderByClause
		for _, item := range sel.OrderBy.Items {
			n, ok := item.Expr.Accept(extractor)
			if !ok {
				return nil, nil, errors.Trace(extractor.err)
			}
			item.Expr = n.(ast.ExprNode)
		}
	}
	// Publish the select fields, including the auxiliary ones appended above.
	sel.Fields.Fields = extractor.selectFields
	return havingAggMapper, extractor.aggMapper, nil
}

// extractAggFuncs runs an AggregateFuncExtractor over every select field and
// returns the collected aggregate functions together with a map from each
// function to its position in the returned list.
func (b *PlanBuilder) extractAggFuncs(fields []*ast.SelectField) ([]*ast.AggregateFuncExpr, map[*ast.AggregateFuncExpr]int) {
	extractor := &AggregateFuncExtractor{}
	for _, f := range fields {
		n, _ := f.Expr.Accept(extractor)
		f.Expr = n.(ast.ExprNode)
	}
	aggList := extractor.AggFuncs
	totalAggMapper := make(map[*ast.AggregateFuncExpr]int)

	for i, agg := range aggList {
		totalAggMapper[agg] = i
	}
	return aggList, totalAggMapper
}

// resolveWindowFunction will process window functions and resolve the columns that don't exist in select fields.
// Only select fields that carry the window flag are visited. The returned map
// is the aggregate mapper filled in during the visit.
func (b *PlanBuilder) resolveWindowFunction(sel *ast.SelectStmt, p LogicalPlan) (
	map[*ast.AggregateFuncExpr]int, error) {
	extractor := &havingWindowAndOrderbyExprResolver{
		p:            p,
		selectFields: sel.Fields.Fields,
		aggMapper:    make(map[*ast.AggregateFuncExpr]int),
		colMapper:    b.colMapper,
		outerSchemas: b.outerSchemas,
	}
	extractor.curClause = windowClause
	for _, field := range sel.Fields.Fields {
		if !ast.HasWindowFlag(field.Expr) {
			continue
		}
		n, ok := field.Expr.Accept(extractor)
		if !ok {
			return nil, extractor.err
		}
		field.Expr = n.(ast.ExprNode)
	}
	sel.Fields.Fields = extractor.selectFields
	return extractor.aggMapper, nil
}

// gbyResolver resolves group by items from select fields.
type gbyResolver struct { ctx sessionctx.Context fields []*ast.SelectField schema *expression.Schema err error inExpr bool isParam bool } func (g *gbyResolver) Enter(inNode ast.Node) (ast.Node, bool) { switch n := inNode.(type) { case *ast.SubqueryExpr, *ast.CompareSubqueryExpr, *ast.ExistsSubqueryExpr: return inNode, true case *driver.ParamMarkerExpr: newNode := expression.ConstructPositionExpr(n) g.isParam = true return newNode, true case *driver.ValueExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.ColumnName: default: g.inExpr = true } return inNode, false } func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) { extractor := &AggregateFuncExtractor{} switch v := inNode.(type) { case *ast.ColumnNameExpr: col, err := g.schema.FindColumn(v.Name) if col == nil || !g.inExpr { var index int index, g.err = resolveFromSelectFields(v, g.fields, false) if g.err != nil { return inNode, false } if col != nil { return inNode, true } if index != -1 { ret := g.fields[index].Expr ret.Accept(extractor) if len(extractor.AggFuncs) != 0 { err = ErrIllegalReference.GenWithStackByArgs(v.Name.OrigColName(), "reference to group function") } else if ast.HasWindowFlag(ret) { err = ErrIllegalReference.GenWithStackByArgs(v.Name.OrigColName(), "reference to window function") } else { return ret, true } } g.err = errors.Trace(err) return inNode, false } case *ast.PositionExpr: pos, isNull, err := expression.PosFromPositionExpr(g.ctx, v) if err != nil { g.err = ErrUnknown.GenWithStackByArgs() } if err != nil || isNull { return inNode, false } if pos < 1 || pos > len(g.fields) { g.err = errors.Errorf("Unknown column '%d' in 'group statement'", pos) return inNode, false } ret := g.fields[pos-1].Expr ret.Accept(extractor) if len(extractor.AggFuncs) != 0 { g.err = ErrWrongGroupField.GenWithStackByArgs(g.fields[pos-1].Text()) return inNode, false } return ret, true case *ast.ValuesExpr: if v.Column == nil { g.err = ErrUnknownColumn.GenWithStackByArgs("", "VALUES() function") } } 
return inNode, true } func tblInfoFromCol(from ast.ResultSetNode, col *expression.Column) *model.TableInfo { var tableList []*ast.TableName tableList = extractTableList(from, tableList, true) for _, field := range tableList { if field.Name.L == col.TblName.L { return field.TableInfo } if field.Name.L != col.TblName.L { continue } if field.Schema.L == col.DBName.L { return field.TableInfo } } return nil } func buildFuncDependCol(p LogicalPlan, cond ast.ExprNode) (*expression.Column, *expression.Column) { binOpExpr, ok := cond.(*ast.BinaryOperationExpr) if !ok { return nil, nil } if binOpExpr.Op != opcode.EQ { return nil, nil } lColExpr, ok := binOpExpr.L.(*ast.ColumnNameExpr) if !ok { return nil, nil } rColExpr, ok := binOpExpr.R.(*ast.ColumnNameExpr) if !ok { return nil, nil } lCol, err := p.Schema().FindColumn(lColExpr.Name) if err != nil { return nil, nil } rCol, err := p.Schema().FindColumn(rColExpr.Name) if err != nil { return nil, nil } return lCol, rCol } func buildWhereFuncDepend(p LogicalPlan, where ast.ExprNode) map[*expression.Column]*expression.Column { whereConditions := splitWhere(where) colDependMap := make(map[*expression.Column]*expression.Column, 2*len(whereConditions)) for _, cond := range whereConditions { lCol, rCol := buildFuncDependCol(p, cond) if lCol == nil || rCol == nil { continue } colDependMap[lCol] = rCol colDependMap[rCol] = lCol } return colDependMap } func buildJoinFuncDepend(p LogicalPlan, from ast.ResultSetNode) map[*expression.Column]*expression.Column { switch x := from.(type) { case *ast.Join: if x.On == nil { return nil } onConditions := splitWhere(x.On.Expr) colDependMap := make(map[*expression.Column]*expression.Column, len(onConditions)) for _, cond := range onConditions { lCol, rCol := buildFuncDependCol(p, cond) if lCol == nil || rCol == nil { continue } lTbl := tblInfoFromCol(x.Left, lCol) if lTbl == nil { lCol, rCol = rCol, lCol } switch x.Tp { case ast.CrossJoin: colDependMap[lCol] = rCol colDependMap[rCol] = lCol case 
ast.LeftJoin: colDependMap[rCol] = lCol case ast.RightJoin: colDependMap[lCol] = rCol } } return colDependMap default: return nil } } func checkColFuncDepend(p LogicalPlan, col *expression.Column, tblInfo *model.TableInfo, gbyCols map[*expression.Column]struct{}, whereDepends, joinDepends map[*expression.Column]*expression.Column) bool { for _, index := range tblInfo.Indices { if !index.Unique { continue } funcDepend := true for _, indexCol := range index.Columns { iColInfo := tblInfo.Columns[indexCol.Offset] if !mysql.HasNotNullFlag(iColInfo.Flag) { funcDepend = false break } cn := &ast.ColumnName{ Schema: col.DBName, Table: col.TblName, Name: iColInfo.Name, } iCol, err := p.Schema().FindColumn(cn) if err != nil || iCol == nil { funcDepend = false break } if _, ok := gbyCols[iCol]; ok { continue } if wCol, ok := whereDepends[iCol]; ok { if _, ok = gbyCols[wCol]; ok { continue } } if jCol, ok := joinDepends[iCol]; ok { if _, ok = gbyCols[jCol]; ok { continue } } funcDepend = false break } if funcDepend { return true } } primaryFuncDepend := true hasPrimaryField := false for _, colInfo := range tblInfo.Columns { if !mysql.HasPriKeyFlag(colInfo.Flag) { continue } hasPrimaryField = true pCol, err := p.Schema().FindColumn(&ast.ColumnName{ Schema: col.DBName, Table: col.TblName, Name: colInfo.Name, }) if err != nil { primaryFuncDepend = false break } if _, ok := gbyCols[pCol]; ok { continue } if wCol, ok := whereDepends[pCol]; ok { if _, ok = gbyCols[wCol]; ok { continue } } if jCol, ok := joinDepends[pCol]; ok { if _, ok = gbyCols[jCol]; ok { continue } } primaryFuncDepend = false break } return primaryFuncDepend && hasPrimaryField } // ErrExprLoc is for generate the ErrFieldNotInGroupBy error info type ErrExprLoc struct { Offset int Loc string } func checkExprInGroupBy(p LogicalPlan, expr ast.ExprNode, offset int, loc string, gbyCols map[*expression.Column]struct{}, gbyExprs []ast.ExprNode, notInGbyCols map[*expression.Column]ErrExprLoc) { if _, ok := 
expr.(*ast.AggregateFuncExpr); ok { return } if _, ok := expr.(*ast.ColumnNameExpr); !ok { for _, gbyExpr := range gbyExprs { if reflect.DeepEqual(gbyExpr, expr) { return } } } // Function `any_value` can be used in aggregation, even `ONLY_FULL_GROUP_BY` is set. // See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_any-value for details if f, ok := expr.(*ast.FuncCallExpr); ok { if f.FnName.L == ast.AnyValue { return } } colMap := make(map[*expression.Column]struct{}, len(p.Schema().Columns)) allColFromExprNode(p, expr, colMap) for col := range colMap { if _, ok := gbyCols[col]; !ok { notInGbyCols[col] = ErrExprLoc{Offset: offset, Loc: loc} } } } func (b *PlanBuilder) checkOnlyFullGroupBy(p LogicalPlan, sel *ast.SelectStmt) (err error) { if sel.GroupBy != nil { err = b.checkOnlyFullGroupByWithGroupClause(p, sel) } else { err = b.checkOnlyFullGroupByWithOutGroupClause(p, sel.Fields.Fields) } return errors.Trace(err) } func (b *PlanBuilder) checkOnlyFullGroupByWithGroupClause(p LogicalPlan, sel *ast.SelectStmt) error { gbyCols := make(map[*expression.Column]struct{}, len(sel.Fields.Fields)) gbyExprs := make([]ast.ExprNode, 0, len(sel.Fields.Fields)) schema := p.Schema() for _, byItem := range sel.GroupBy.Items { if colExpr, ok := byItem.Expr.(*ast.ColumnNameExpr); ok { col, err := schema.FindColumn(colExpr.Name) if err != nil || col == nil { continue } gbyCols[col] = struct{}{} } else { gbyExprs = append(gbyExprs, byItem.Expr) } } notInGbyCols := make(map[*expression.Column]ErrExprLoc, len(sel.Fields.Fields)) for offset, field := range sel.Fields.Fields { if field.Auxiliary { continue } checkExprInGroupBy(p, field.Expr, offset, ErrExprInSelect, gbyCols, gbyExprs, notInGbyCols) } if sel.OrderBy != nil { for offset, item := range sel.OrderBy.Items { checkExprInGroupBy(p, item.Expr, offset, ErrExprInOrderBy, gbyCols, gbyExprs, notInGbyCols) } } if len(notInGbyCols) == 0 { return nil } whereDepends := buildWhereFuncDepend(p, sel.Where) 
joinDepends := buildJoinFuncDepend(p, sel.From.TableRefs) tblMap := make(map[*model.TableInfo]struct{}, len(notInGbyCols)) for col, errExprLoc := range notInGbyCols { tblInfo := tblInfoFromCol(sel.From.TableRefs, col) if tblInfo == nil { continue } if _, ok := tblMap[tblInfo]; ok { continue } if checkColFuncDepend(p, col, tblInfo, gbyCols, whereDepends, joinDepends) { tblMap[tblInfo] = struct{}{} continue } switch errExprLoc.Loc { case ErrExprInSelect: return ErrFieldNotInGroupBy.GenWithStackByArgs(errExprLoc.Offset+1, errExprLoc.Loc, sel.Fields.Fields[errExprLoc.Offset].Text()) case ErrExprInOrderBy: return ErrFieldNotInGroupBy.GenWithStackByArgs(errExprLoc.Offset+1, errExprLoc.Loc, sel.OrderBy.Items[errExprLoc.Offset].Expr.Text()) } return nil } return nil } func (b *PlanBuilder) checkOnlyFullGroupByWithOutGroupClause(p LogicalPlan, fields []*ast.SelectField) error { resolver := colResolverForOnlyFullGroupBy{} for idx, field := range fields { resolver.exprIdx = idx field.Accept(&resolver) err := resolver.Check() if err != nil { return errors.Trace(err) } } return nil } // colResolverForOnlyFullGroupBy visits Expr tree to find out if an Expr tree is an aggregation function. // If so, find out the first column name that not in an aggregation function. 
type colResolverForOnlyFullGroupBy struct {
	// firstNonAggCol is the first column seen outside an aggregate function,
	// and firstNonAggColIdx the index of the expression it was seen in.
	firstNonAggCol *ast.ColumnName
	// exprIdx is the index of the expression being visited; the caller sets it.
	exprIdx              int
	firstNonAggColIdx    int
	hasAggFuncOrAnyValue bool
}

// Enter implements Visitor interface. The children of aggregate functions,
// `any_value` calls, column names and subqueries are skipped.
func (c *colResolverForOnlyFullGroupBy) Enter(node ast.Node) (ast.Node, bool) {
	switch t := node.(type) {
	case *ast.AggregateFuncExpr:
		c.hasAggFuncOrAnyValue = true
		return node, true
	case *ast.FuncCallExpr:
		// enable function `any_value` in aggregation even `ONLY_FULL_GROUP_BY` is set
		if t.FnName.L == ast.AnyValue {
			c.hasAggFuncOrAnyValue = true
			return node, true
		}
	case *ast.ColumnNameExpr:
		if c.firstNonAggCol == nil {
			c.firstNonAggCol, c.firstNonAggColIdx = t.Name, c.exprIdx
		}
		return node, true
	case *ast.SubqueryExpr:
		return node, true
	}
	return node, false
}

// Leave implements Visitor interface.
func (c *colResolverForOnlyFullGroupBy) Leave(node ast.Node) (ast.Node, bool) {
	return node, true
}

// Check returns ErrMixOfGroupFuncAndFields when both an aggregate function (or
// `any_value`) and a column outside of one have been seen so far.
func (c *colResolverForOnlyFullGroupBy) Check() error {
	if c.hasAggFuncOrAnyValue && c.firstNonAggCol != nil {
		return ErrMixOfGroupFuncAndFields.GenWithStackByArgs(c.firstNonAggColIdx+1, c.firstNonAggCol.Name.O)
	}
	return nil
}

// colResolver collects into cols the schema columns of p that an expression
// tree references.
type colResolver struct {
	p    LogicalPlan
	cols map[*expression.Column]struct{}
}

// Enter implements Visitor interface. The children of column names, subqueries
// and aggregate functions are skipped.
func (c *colResolver) Enter(inNode ast.Node) (ast.Node, bool) {
	switch inNode.(type) {
	case *ast.ColumnNameExpr, *ast.SubqueryExpr, *ast.AggregateFuncExpr:
		return inNode, true
	}
	return inNode, false
}

// Leave implements Visitor interface. A column name found in the schema of p
// is added to cols; lookup failures are ignored.
func (c *colResolver) Leave(inNode ast.Node) (ast.Node, bool) {
	switch v := inNode.(type) {
	case *ast.ColumnNameExpr:
		col, err := c.p.Schema().FindColumn(v.Name)
		if err == nil && col != nil {
			c.cols[col] = struct{}{}
		}
	}
	return inNode, true
}

// allColFromExprNode adds to cols the columns of p that are referenced by n,
// as collected by colResolver.
func allColFromExprNode(p LogicalPlan, n ast.Node, cols map[*expression.Column]struct{}) {
	extractor := &colResolver{
		p:    p,
		cols: cols,
	}
	n.Accept(extractor)
}

// resolveGbyExprs resolves the GROUP BY items against the select fields and
// rewrites them into expressions. It returns the plan produced by the
// rewrites together with the expressions, in the order of the items.
func (b *PlanBuilder) resolveGbyExprs(p LogicalPlan, gby *ast.GroupByClause, fields []*ast.SelectField) (LogicalPlan, []expression.Expression, error) {
	b.curClause = groupByClause
	exprs := make([]expression.Expression, 0, len(gby.Items))
	resolver := &gbyResolver{
		ctx:    b.ctx,
		fields: fields,
		schema: p.Schema(),
	}
	for _, item := range gby.Items {
		resolver.inExpr = false
		retExpr, _ := item.Expr.Accept(resolver)
		if resolver.err != nil {
			return nil, nil, errors.Trace(resolver.err)
		}
		// Once a parameter marker has been seen, the resolved expression is
		// no longer written back into the item.
		if !resolver.isParam {
			item.Expr = retExpr.(ast.ExprNode)
		}

		itemExpr := retExpr.(ast.ExprNode)
		expr, np, err := b.rewrite(itemExpr, p, nil, true)
		if err != nil {
			return nil, nil, errors.Trace(err)
		}

		exprs = append(exprs, expr)
		p = np
	}
	return p, exprs, nil
}

// unfoldWildStar expands every wildcard select field (`*`, `t.*`, `db.t.*`)
// into one field per matching column of p's schema, skipping the extra handle
// column. An unqualified `*` is only accepted as the first field, and
// ErrBadTable is returned when a wildcard matches no column.
func (b *PlanBuilder) unfoldWildStar(p LogicalPlan, selectFields []*ast.SelectField) (resultList []*ast.SelectField, err error) {
	for i, field := range selectFields {
		if field.WildCard == nil {
			resultList = append(resultList, field)
			continue
		}
		if field.WildCard.Table.L == "" && i > 0 {
			return nil, ErrInvalidWildCard
		}
		dbName := field.WildCard.Schema
		tblName := field.WildCard.Table
		findTblNameInSchema := false
		for _, col := range p.Schema().Columns {
			if (dbName.L == "" || dbName.L == col.DBName.L) &&
				(tblName.L == "" || tblName.L == col.TblName.L) &&
				col.ID != model.ExtraHandleID {
				findTblNameInSchema = true
				colName := &ast.ColumnNameExpr{
					Name: &ast.ColumnName{
						Schema: col.DBName,
						Table:  col.TblName,
						Name:   col.ColName,
					}}
				colName.SetType(col.GetType())
				field := &ast.SelectField{Expr: colName}
				field.SetText(col.ColName.O)
				resultList = append(resultList, field)
			}
		}
		if !findTblNameInSchema {
			return nil, ErrBadTable.GenWithStackByArgs(tblName)
		}
	}
	return resultList, nil
}

// pushTableHints collects the join hints this builder implements and, when at
// least one hinted table was found, pushes them as a new entry onto
// b.tableHintInfo. It reports whether an entry was pushed.
func (b *PlanBuilder) pushTableHints(hints []*ast.TableOptimizerHint) bool {
	var sortMergeTables, INLJTables, hashJoinTables []model.CIStr
	for _, hint := range hints {
		switch hint.HintName.L {
		case TiDBMergeJoin:
			sortMergeTables = append(sortMergeTables, hint.Tables...)
		case TiDBIndexNestedLoopJoin:
			INLJTables = append(INLJTables, hint.Tables...)
		case TiDBHashJoin:
			hashJoinTables = append(hashJoinTables, hint.Tables...)
		default:
			// ignore hints that not implemented
		}
	}
	if len(sortMergeTables)+len(INLJTables)+len(hashJoinTables) > 0 {
		b.tableHintInfo = append(b.tableHintInfo, tableHintInfo{
			sortMergeJoinTables:       sortMergeTables,
			indexNestedLoopJoinTables: INLJTables,
			hashJoinTables:            hashJoinTables,
		})
		return true
	}
	return false
}

// popTableHints removes the last entry of b.tableHintInfo. It must only be
// called after pushTableHints has reported that it pushed an entry.
func (b *PlanBuilder) popTableHints() {
	b.tableHintInfo = b.tableHintInfo[:len(b.tableHintInfo)-1]
}

// TableHints returns the *tableHintInfo of PlanBuilder.
// It points at the most recently pushed entry and is nil when there is none.
func (b *PlanBuilder) TableHints() *tableHintInfo {
	if len(b.tableHintInfo) == 0 {
		return nil
	}
	return &(b.tableHintInfo[len(b.tableHintInfo)-1])
}

// buildSelect builds the logical plan of a SELECT statement. The steps run in
// a fixed order: FROM, wildcard expansion, GROUP BY resolution, the
// ONLY_FULL_GROUP_BY check, window / HAVING / ORDER BY resolution, WHERE,
// lock, aggregation, projection, HAVING, window projection, DISTINCT,
// ORDER BY and LIMIT. The select fields of sel are restored before returning
// successfully.
func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error) {
	if b.pushTableHints(sel.TableHints) {
		// table hints are only visible in the current SELECT statement.
		defer b.popTableHints()
	}
	if sel.SelectStmtOpts != nil {
		// Restore the previous STRAIGHT_JOIN setting when this statement is done.
		origin := b.inStraightJoin
		b.inStraightJoin = sel.SelectStmtOpts.StraightJoin
		defer func() { b.inStraightJoin = origin }()
	}

	var (
		aggFuncs                      []*ast.AggregateFuncExpr
		havingMap, orderMap, totalMap map[*ast.AggregateFuncExpr]int
		windowMap                     map[*ast.AggregateFuncExpr]int
		gbyCols                       []expression.Expression
	)

	if sel.From != nil {
		p, err = b.buildResultSetNode(sel.From.TableRefs)
		if err != nil {
			return nil, errors.Trace(err)
		}
	} else {
		p = b.buildTableDual()
	}

	// Keep the user-written fields; the steps below replace sel.Fields.Fields.
	originalFields := sel.Fields.Fields
	sel.Fields.Fields, err = b.unfoldWildStar(p, sel.Fields.Fields)
	if err != nil {
		return nil, errors.Trace(err)
	}

	if sel.GroupBy != nil {
		p, gbyCols, err = b.resolveGbyExprs(p, sel.GroupBy, sel.Fields.Fields)
		if err != nil {
			return nil, errors.Trace(err)
		}
	}

	if b.ctx.GetSessionVars().SQLMode.HasOnlyFullGroupBy() && sel.From != nil {
		err = b.checkOnlyFullGroupBy(p, sel)
		if err != nil {
			return nil, errors.Trace(err)
		}
	}

	hasWindowFuncField := b.detectSelectWindow(sel)
	if hasWindowFuncField {
		windowMap, err = b.resolveWindowFunction(sel, p)
		if err != nil {
			return nil, err
		}
	}
	// We must resolve having and order by clause before build projection,
	// because when the query is "select a+1 as b from t having sum(b) < 0", we must replace sum(b) to sum(a+1),
	// which only can be done before building projection and extracting Agg functions.
	havingMap, orderMap, err = b.resolveHavingAndOrderBy(sel, p)
	if err != nil {
		return nil, errors.Trace(err)
	}

	if sel.Where != nil {
		p, err = b.buildSelection(p, sel.Where, nil)
		if err != nil {
			return nil, errors.Trace(err)
		}
	}

	if sel.LockTp != ast.SelectLockNone {
		p = b.buildSelectLock(p, sel.LockTp)
	}

	hasAgg := b.detectSelectAgg(sel)
	if hasAgg {
		aggFuncs, totalMap = b.extractAggFuncs(sel.Fields.Fields)
		var aggIndexMap map[int]int
		p, aggIndexMap, err = b.buildAggregation(p, aggFuncs, gbyCols)
		if err != nil {
			return nil, errors.Trace(err)
		}
		// Translate the indexes into the ones returned by buildAggregation.
		for k, v := range totalMap {
			totalMap[k] = aggIndexMap[v]
		}
	}

	var oldLen int
	// According to https://dev.mysql.com/doc/refman/8.0/en/window-functions-usage.html,
	// we can only process window functions after having clause, so `considerWindow` is false now.
	p, oldLen, err = b.buildProjection(p, sel.Fields.Fields, totalMap, false)
	if err != nil {
		return nil, errors.Trace(err)
	}

	if sel.Having != nil {
		b.curClause = havingClause
		p, err = b.buildSelection(p, sel.Having.Expr, havingMap)
		if err != nil {
			return nil, errors.Trace(err)
		}
	}

	b.windowSpecs, err = buildWindowSpecs(sel.WindowSpecs)
	if err != nil {
		return nil, err
	}

	if hasWindowFuncField {
		// Now we build the window function fields.
		p, oldLen, err = b.buildProjection(p, sel.Fields.Fields, windowMap, true)
		if err != nil {
			return nil, err
		}
	}

	if sel.Distinct {
		p = b.buildDistinct(p, oldLen)
	}

	if sel.OrderBy != nil {
		p, err = b.buildSort(p, sel.OrderBy.Items, orderMap)
		if err != nil {
			return nil, errors.Trace(err)
		}
	}

	if sel.Limit != nil {
		p, err = b.buildLimit(p, sel.Limit)
		if err != nil {
			return nil, errors.Trace(err)
		}
	}

	sel.Fields.Fields = originalFields
	// When the plan's schema length differs from oldLen, add a projection that
	// keeps only the first oldLen columns, each with a freshly allocated
	// unique ID.
	if oldLen != p.Schema().Len() {
		proj := LogicalProjection{Exprs: expression.Column2Exprs(p.Schema().Columns[:oldLen])}.Init(b.ctx)
		proj.SetChildren(p)
		schema := expression.NewSchema(p.Schema().Clone().Columns[:oldLen]...)
		for _, col := range schema.Columns {
			col.UniqueID = b.ctx.GetSessionVars().AllocPlanColumnID()
		}
		proj.SetSchema(schema)
		return proj, nil
	}

	return p, nil
}

// buildTableDual returns a LogicalTableDual with RowCount 1. buildSelect uses
// it as the source of a SELECT without a FROM clause.
func (b *PlanBuilder) buildTableDual() *LogicalTableDual {
	return LogicalTableDual{RowCount: 1}.Init(b.ctx)
}

// newExtraHandleSchemaCol creates the schema column for the extra handle of
// ds, with a freshly allocated unique ID.
func (ds *DataSource) newExtraHandleSchemaCol() *expression.Column {
	return &expression.Column{
		DBName:   ds.DBName,
		TblName:  ds.tableInfo.Name,
		ColName:  model.ExtraHandleName,
		RetType:  types.NewFieldType(mysql.TypeLonglong),
		UniqueID: ds.ctx.GetSessionVars().AllocPlanColumnID(),
		ID:       model.ExtraHandleID,
	}
}

// getStatsTable gets statistics information for a table specified by "tableID".
// A pseudo statistics table is returned in any of the following scenario:
// 1. tidb-server started and statistics handle has not been initialized.
// 2. table row count from statistics is zero.
// 3. statistics is outdated.
// When pid differs from the table ID, the statistics of partition pid are used.
func getStatsTable(ctx sessionctx.Context, tblInfo *model.TableInfo, pid int64) *statistics.Table {
	statsHandle := domain.GetDomain(ctx).StatsHandle()

	// 1. tidb-server started and statistics handle has not been initialized.
	if statsHandle == nil {
		return statistics.PseudoTable(tblInfo)
	}

	var statsTbl *statistics.Table
	if pid != tblInfo.ID {
		statsTbl = statsHandle.GetPartitionStats(tblInfo, pid)
	} else {
		statsTbl = statsHandle.GetTableStats(tblInfo)
	}

	// 2. table row count from statistics is zero.
	if statsTbl.Count == 0 {
		return statistics.PseudoTable(tblInfo)
	}

	// 3. statistics is outdated.
	if statsTbl.IsOutdated() {
		// Mark a shallow copy as pseudo instead of modifying the table
		// returned by the handle.
		tbl := *statsTbl
		tbl.Pseudo = true
		statsTbl = &tbl
		metrics.PseudoEstimation.Inc()
	}
	return statsTbl
}

// buildDataSource builds the plan that reads the table named by tn. Views are
// delegated to BuildDataSourceFromView. For a table, the result is a
// DataSource, wrapped in a LogicalUnionScan when the current transaction is
// valid and not read-only, and in a projection when the table has virtual
// generated columns. The SELECT privilege on the table is recorded in
// b.visitInfo.
func (b *PlanBuilder) buildDataSource(tn *ast.TableName) (LogicalPlan, error) {
	dbName := tn.Schema
	if dbName.L == "" {
		dbName = model.NewCIStr(b.ctx.GetSessionVars().CurrentDB)
	}

	tbl, err := b.is.TableByName(dbName, tn.Name)
	if err != nil {
		return nil, errors.Trace(err)
	}

	tableInfo := tbl.Meta()
	// authErr is prepared up front and handed to appendVisitInfo below.
	var authErr error = nil
	if b.ctx.GetSessionVars().User != nil {
		authErr = ErrTableaccessDenied.GenWithStackByArgs("SELECT", b.ctx.GetSessionVars().User.Hostname, b.ctx.GetSessionVars().User.Username, tableInfo.Name.L)
	}
	b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, dbName.L, tableInfo.Name.L, "", authErr)

	if tableInfo.IsView() {
		return b.BuildDataSourceFromView(dbName, tableInfo)
	}

	if tableInfo.GetPartitionInfo() != nil {
		b.optFlag = b.optFlag | flagPartitionProcessor
		// check partition by name.
		for _, name := range tn.PartitionNames {
			_, err = tables.FindPartitionByName(tableInfo, name.L)
			if err != nil {
				return nil, errors.Trace(err)
			}
		}
	} else if len(tn.PartitionNames) != 0 {
		return nil, ErrPartitionClauseOnNonpartitioned
	}

	possiblePaths, err := getPossibleAccessPaths(tn.IndexHints, tableInfo)
	if err != nil {
		return nil, errors.Trace(err)
	}

	var columns []*table.Column
	if b.inUpdateStmt {
		// create table t(a int, b int).
		// Imagine that, There are 2 TiDB instances in the cluster, name A, B. We add a column `c` to table t in the TiDB cluster.
		// One of the TiDB, A, the column type in its infoschema is changed to public. And in the other TiDB, the column type is
		// still StateWriteReorganization.
		// TiDB A: insert into t values(1, 2, 3);
		// TiDB B: update t set a = 2 where b = 2;
		// If we use tbl.Cols() here, the update statement, will ignore the col `c`, and the data `3` will lost.
		columns = tbl.WritableCols()
	} else {
		columns = tbl.Cols()
	}
	// Statistics are only loaded here when tbl is not a table.PartitionedTable.
	var statisticTable *statistics.Table
	if _, ok := tbl.(table.PartitionedTable); !ok {
		statisticTable = getStatsTable(b.ctx, tbl.Meta(), tbl.Meta().ID)
	}

	ds := DataSource{
		DBName:              dbName,
		table:               tbl,
		tableInfo:           tableInfo,
		statisticTable:      statisticTable,
		indexHints:          tn.IndexHints,
		possibleAccessPaths: possiblePaths,
		Columns:             make([]*model.ColumnInfo, 0, len(columns)),
		partitionNames:      tn.PartitionNames,
	}.Init(b.ctx)

	var handleCol *expression.Column
	schema := expression.NewSchema(make([]*expression.Column, 0, len(columns))...)
	for _, col := range columns {
		ds.Columns = append(ds.Columns, col.ToInfo())
		newCol := &expression.Column{
			UniqueID:    b.ctx.GetSessionVars().AllocPlanColumnID(),
			DBName:      dbName,
			TblName:     tableInfo.Name,
			ColName:     col.Name,
			OrigColName: col.Name,
			ID:          col.ID,
			RetType:     &col.FieldType,
		}

		// A primary key column serves as the handle when PKIsHandle is set.
		if tableInfo.PKIsHandle && mysql.HasPriKeyFlag(col.Flag) {
			handleCol = newCol
		}
		schema.Append(newCol)
	}
	ds.SetSchema(schema)

	// We append an extra handle column to the schema when "ds" is not a memory
	// table e.g. table in the "INFORMATION_SCHEMA" database, and the handle
	// column is not the primary key of "ds".
	isMemDB := infoschema.IsMemoryDB(ds.DBName.L)
	if !isMemDB && handleCol == nil {
		ds.Columns = append(ds.Columns, model.NewExtraHandleColInfo())
		handleCol = ds.newExtraHandleSchemaCol()
		schema.Append(handleCol)
	}
	if handleCol != nil {
		schema.TblID2Handle[tableInfo.ID] = []*expression.Column{handleCol}
	}

	var result LogicalPlan = ds

	// If this SQL is executed in a non-readonly transaction, we need a
	// "UnionScan" operator to read the modifications of former SQLs, which is
	// buffered in tidb-server memory.
	txn, err := b.ctx.Txn(false)
	if err != nil {
		return nil, errors.Trace(err)
	}
	if txn.Valid() && !txn.IsReadOnly() {
		us := LogicalUnionScan{}.Init(b.ctx)
		us.SetChildren(ds)
		result = us
	}

	// If this table contains any virtual generated columns, we need a
	// "Projection" to calculate these columns.
	proj, err := b.projectVirtualColumns(ds, columns)
	if err != nil {
		return nil, errors.Trace(err)
	}

	if proj != nil {
		proj.SetChildren(result)
		result = proj
	}
	return result, nil
}

// BuildDataSourceFromView is used to build LogicalPlan from view
// The view's SELECT statement is parsed and built, and a projection exposing
// the view's columns is put on top of it. For a view with definer security,
// the privileges required by the view's statement are verified against the
// definer and are not added to the builder's visit info.
func (b *PlanBuilder) BuildDataSourceFromView(dbName model.CIStr, tableInfo *model.TableInfo) (LogicalPlan, error) {
	charset, collation := b.ctx.GetSessionVars().GetCharsetInfo()
	viewParser := parser.New()
	viewParser.EnableWindowFunc(b.ctx.GetSessionVars().EnableWindowFunction)
	selectNode, err := viewParser.ParseOneStmt(tableInfo.View.SelectStmt, charset, collation)
	if err != nil {
		return nil, err
	}
	// Collect the visit info of the view's statement separately so that it can
	// be checked against the definer below.
	originalVisitInfo := b.visitInfo
	b.visitInfo = make([]visitInfo, 0)
	selectLogicalPlan, err := b.Build(selectNode)
	if err != nil {
		return nil, err
	}

	if tableInfo.View.Security == model.SecurityDefiner {
		if pm := privilege.GetPrivilegeManager(b.ctx); pm != nil {
			for _, v := range b.visitInfo {
				if !pm.RequestVerificationWithUser(v.db, v.table, v.column, v.privilege, tableInfo.View.Definer) {
					return nil, ErrViewInvalid.GenWithStackByArgs(dbName.O, tableInfo.Name.O)
				}
			}
		}
		b.visitInfo = b.visitInfo[:0]
	}
	b.visitInfo = append(originalVisitInfo, b.visitInfo...)

	projSchema := expression.NewSchema(make([]*expression.Column, 0, len(tableInfo.View.Cols))...)
	projExprs := make([]expression.Expression, 0, len(tableInfo.View.Cols))
	for i := range tableInfo.View.Cols {
		// Every column recorded in the view definition must still be produced
		// by the view's statement.
		col := selectLogicalPlan.Schema().FindColumnByName(tableInfo.View.Cols[i].L)
		if col == nil {
			return nil, ErrViewInvalid.GenWithStackByArgs(dbName.O, tableInfo.Name.O)
		}
		projSchema.Append(&expression.Column{
			UniqueID:    b.ctx.GetSessionVars().AllocPlanColumnID(),
			TblName:     col.TblName,
			OrigTblName: col.OrigTblName,
			ColName:     tableInfo.Cols()[i].Name,
			OrigColName: tableInfo.View.Cols[i],
			DBName:      col.DBName,
			RetType:     col.GetType(),
		})
		projExprs = append(projExprs, col)
	}

	projUponView := LogicalProjection{Exprs: projExprs}.Init(b.ctx)
	projUponView.SetChildren(selectLogicalPlan.(LogicalPlan))
	projUponView.SetSchema(projSchema)
	return projUponView, nil
}

// projectVirtualColumns is only for DataSource. If some table has virtual generated columns,
// we add a projection on the original DataSource, and calculate those columns in the projection
// so that plans above it can reference generated columns by their name.
func (b *PlanBuilder) projectVirtualColumns(ds *DataSource, columns []*table.Column) (*LogicalProjection, error) {
	var hasVirtualGeneratedColumn = false
	for _, column := range columns {
		if column.IsGenerated() && !column.GeneratedStored {
			hasVirtualGeneratedColumn = true
			break
		}
	}
	if !hasVirtualGeneratedColumn {
		return nil, nil
	}
	var proj = LogicalProjection{
		Exprs:            make([]expression.Expression, 0, len(columns)),
		calculateGenCols: true,
	}.Init(b.ctx)

	for i, colExpr := range ds.Schema().Columns {
		var exprIsGen = false
		var expr expression.Expression
		if i < len(columns) {
			if columns[i].IsGenerated() && !columns[i].GeneratedStored {
				var err error
				expr, _, err = b.rewrite(columns[i].GeneratedExpr, ds, nil, true)
				if err != nil {
					return nil, errors.Trace(err)
				}
				// Because the expression might return different type from
				// the generated column, we should wrap a CAST on the result.
expr = expression.BuildCastFunction(b.ctx, expr, colExpr.GetType()) exprIsGen = true } } if !exprIsGen { expr = colExpr } proj.Exprs = append(proj.Exprs, expr) } // Re-iterate expressions to handle those virtual generated columns that refers to the other generated columns, for // example, given: // column a, column b as (a * 2), column c as (b + 1) // we'll get: // column a, column b as (a * 2), column c as ((a * 2) + 1) // A generated column definition can refer to only generated columns occurring earlier in the table definition, so // it's safe to iterate in index-ascending order. for i, expr := range proj.Exprs { proj.Exprs[i] = expression.ColumnSubstitute(expr, ds.Schema(), proj.Exprs) } proj.SetSchema(ds.Schema().Clone()) return proj, nil } // buildApplyWithJoinType builds apply plan with outerPlan and innerPlan, which apply join with particular join type for // every row from outerPlan and the whole innerPlan. func (b *PlanBuilder) buildApplyWithJoinType(outerPlan, innerPlan LogicalPlan, tp JoinType) LogicalPlan { b.optFlag = b.optFlag | flagPredicatePushDown b.optFlag = b.optFlag | flagBuildKeyInfo b.optFlag = b.optFlag | flagDecorrelate ap := LogicalApply{LogicalJoin: LogicalJoin{JoinType: tp}}.Init(b.ctx) ap.SetChildren(outerPlan, innerPlan) ap.SetSchema(expression.MergeSchema(outerPlan.Schema(), innerPlan.Schema())) for i := outerPlan.Schema().Len(); i < ap.Schema().Len(); i++ { ap.schema.Columns[i].IsReferenced = true } return ap } // buildSemiApply builds apply plan with outerPlan and innerPlan, which apply semi-join for every row from outerPlan and the whole innerPlan. 
func (b *PlanBuilder) buildSemiApply(outerPlan, innerPlan LogicalPlan, condition []expression.Expression, asScalar, not bool) (LogicalPlan, error) { b.optFlag = b.optFlag | flagPredicatePushDown b.optFlag = b.optFlag | flagBuildKeyInfo b.optFlag = b.optFlag | flagDecorrelate join, err := b.buildSemiJoin(outerPlan, innerPlan, condition, asScalar, not) if err != nil { return nil, errors.Trace(err) } ap := &LogicalApply{LogicalJoin: *join} ap.tp = TypeApply ap.self = ap return ap, nil } func (b *PlanBuilder) buildMaxOneRow(p LogicalPlan) LogicalPlan { maxOneRow := LogicalMaxOneRow{}.Init(b.ctx) maxOneRow.SetChildren(p) return maxOneRow } func (b *PlanBuilder) buildSemiJoin(outerPlan, innerPlan LogicalPlan, onCondition []expression.Expression, asScalar bool, not bool) (*LogicalJoin, error) { joinPlan := LogicalJoin{}.Init(b.ctx) for i, expr := range onCondition { onCondition[i] = expr.Decorrelate(outerPlan.Schema()) } joinPlan.SetChildren(outerPlan, innerPlan) joinPlan.attachOnConds(onCondition) if asScalar { newSchema := outerPlan.Schema().Clone() newSchema.Append(&expression.Column{ ColName: model.NewCIStr(fmt.Sprintf("%d_aux_0", joinPlan.id)), RetType: types.NewFieldType(mysql.TypeTiny), IsReferenced: true, UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), }) joinPlan.SetSchema(newSchema) if not { joinPlan.JoinType = AntiLeftOuterSemiJoin } else { joinPlan.JoinType = LeftOuterSemiJoin } } else { joinPlan.SetSchema(outerPlan.Schema().Clone()) if not { joinPlan.JoinType = AntiSemiJoin } else { joinPlan.JoinType = SemiJoin } } // Apply forces to choose hash join currently, so don't worry the hints will take effect if the semi join is in one apply. 
if b.TableHints() != nil { outerAlias := extractTableAlias(outerPlan) innerAlias := extractTableAlias(innerPlan) if b.TableHints().ifPreferMergeJoin(outerAlias, innerAlias) { joinPlan.preferJoinType |= preferMergeJoin } if b.TableHints().ifPreferHashJoin(outerAlias, innerAlias) { joinPlan.preferJoinType |= preferHashJoin } if b.TableHints().ifPreferINLJ(innerAlias) { joinPlan.preferJoinType = preferLeftAsIndexInner } // If there're multiple join hints, they're conflict. if bits.OnesCount(joinPlan.preferJoinType) > 1 { return nil, errors.New("Join hints are conflict, you can only specify one type of join") } } return joinPlan, nil } func (b *PlanBuilder) buildUpdate(update *ast.UpdateStmt) (Plan, error) { if b.pushTableHints(update.TableHints) { // table hints are only visible in the current UPDATE statement. defer b.popTableHints() } // update subquery table should be forbidden var asNameList []string asNameList = extractTableSourceAsNames(update.TableRefs.TableRefs, asNameList, true) for _, asName := range asNameList { for _, assign := range update.List { if assign.Column.Table.L == asName { return nil, ErrNonUpdatableTable.GenWithStackByArgs(asName, "UPDATE") } } } b.inUpdateStmt = true sel := &ast.SelectStmt{ Fields: &ast.FieldList{}, From: update.TableRefs, Where: update.Where, OrderBy: update.Order, Limit: update.Limit, } p, err := b.buildResultSetNode(sel.From.TableRefs) if err != nil { return nil, errors.Trace(err) } var tableList []*ast.TableName tableList = extractTableList(sel.From.TableRefs, tableList, false) for _, t := range tableList { dbName := t.Schema.L if dbName == "" { dbName = b.ctx.GetSessionVars().CurrentDB } if t.TableInfo.IsView() { return nil, errors.Errorf("update view %s is not supported now.", t.Name.O) } b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, dbName, t.Name.L, "", nil) } if sel.Where != nil { p, err = b.buildSelection(p, sel.Where, nil) if err != nil { return nil, errors.Trace(err) } } if sel.OrderBy != nil { p, 
err = b.buildSort(p, sel.OrderBy.Items, nil) if err != nil { return nil, errors.Trace(err) } } if sel.Limit != nil { p, err = b.buildLimit(p, sel.Limit) if err != nil { return nil, errors.Trace(err) } } orderedList, np, err := b.buildUpdateLists(tableList, update.List, p) if err != nil { return nil, errors.Trace(err) } p = np updt := Update{OrderedList: orderedList}.Init(b.ctx) updt.SetSchema(p.Schema()) updt.SelectPlan, err = DoOptimize(b.optFlag, p) if err != nil { return nil, errors.Trace(err) } err = updt.ResolveIndices() return updt, err } func (b *PlanBuilder) buildUpdateLists(tableList []*ast.TableName, list []*ast.Assignment, p LogicalPlan) ([]*expression.Assignment, LogicalPlan, error) { b.curClause = fieldList modifyColumns := make(map[string]struct{}, p.Schema().Len()) // Which columns are in set list. for _, assign := range list { col, _, err := p.findColumn(assign.Column) if err != nil { return nil, nil, errors.Trace(err) } columnFullName := fmt.Sprintf("%s.%s.%s", col.DBName.L, col.TblName.L, col.ColName.L) modifyColumns[columnFullName] = struct{}{} } // If columns in set list contains generated columns, raise error. // And, fill virtualAssignments here; that's for generated columns. 
virtualAssignments := make([]*ast.Assignment, 0) tableAsName := make(map[*model.TableInfo][]*model.CIStr) extractTableAsNameForUpdate(p, tableAsName) for _, tn := range tableList { tableInfo := tn.TableInfo tableVal, found := b.is.TableByID(tableInfo.ID) if !found { return nil, nil, infoschema.ErrTableNotExists.GenWithStackByArgs(tn.DBInfo.Name.O, tableInfo.Name.O) } for i, colInfo := range tableInfo.Columns { if !colInfo.IsGenerated() { continue } columnFullName := fmt.Sprintf("%s.%s.%s", tn.Schema.L, tn.Name.L, colInfo.Name.L) if _, ok := modifyColumns[columnFullName]; ok { return nil, nil, ErrBadGeneratedColumn.GenWithStackByArgs(colInfo.Name.O, tableInfo.Name.O) } for _, asName := range tableAsName[tableInfo] { virtualAssignments = append(virtualAssignments, &ast.Assignment{ Column: &ast.ColumnName{Table: *asName, Name: colInfo.Name}, Expr: tableVal.Cols()[i].GeneratedExpr, }) } } } newList := make([]*expression.Assignment, 0, p.Schema().Len()) allAssignments := append(list, virtualAssignments...) 
for i, assign := range allAssignments { col, _, err := p.findColumn(assign.Column) if err != nil { return nil, nil, errors.Trace(err) } var newExpr expression.Expression var np LogicalPlan if i < len(list) { newExpr, np, err = b.rewrite(assign.Expr, p, nil, false) } else { // rewrite with generation expression rewritePreprocess := func(expr ast.Node) ast.Node { switch x := expr.(type) { case *ast.ColumnName: return &ast.ColumnName{ Schema: assign.Column.Schema, Table: assign.Column.Table, Name: x.Name, } default: return expr } } newExpr, np, err = b.rewriteWithPreprocess(assign.Expr, p, nil, false, rewritePreprocess) } if err != nil { return nil, nil, errors.Trace(err) } newExpr = expression.BuildCastFunction(b.ctx, newExpr, col.GetType()) p = np newList = append(newList, &expression.Assignment{Col: col, Expr: newExpr}) } for _, assign := range newList { col := assign.Col dbName := col.DBName.L if dbName == "" { dbName = b.ctx.GetSessionVars().CurrentDB } b.visitInfo = appendVisitInfo(b.visitInfo, mysql.UpdatePriv, dbName, col.OrigTblName.L, "", nil) } return newList, p, nil } // extractTableAsNameForUpdate extracts tables' alias names for update. func extractTableAsNameForUpdate(p LogicalPlan, asNames map[*model.TableInfo][]*model.CIStr) { switch x := p.(type) { case *DataSource: alias := extractTableAlias(p) if alias != nil { if _, ok := asNames[x.tableInfo]; !ok { asNames[x.tableInfo] = make([]*model.CIStr, 0, 1) } asNames[x.tableInfo] = append(asNames[x.tableInfo], alias) } case *LogicalProjection: if !x.calculateGenCols { return } ds, isDS := x.Children()[0].(*DataSource) if !isDS { // try to extract the DataSource below a LogicalUnionScan. 
if us, isUS := x.Children()[0].(*LogicalUnionScan); isUS { ds, isDS = us.Children()[0].(*DataSource) } } if !isDS { return } alias := extractTableAlias(x) if alias != nil { if _, ok := asNames[ds.tableInfo]; !ok { asNames[ds.tableInfo] = make([]*model.CIStr, 0, 1) } asNames[ds.tableInfo] = append(asNames[ds.tableInfo], alias) } default: for _, child := range p.Children() { extractTableAsNameForUpdate(child, asNames) } } } func (b *PlanBuilder) buildDelete(delete *ast.DeleteStmt) (Plan, error) { if b.pushTableHints(delete.TableHints) { // table hints are only visible in the current DELETE statement. defer b.popTableHints() } sel := &ast.SelectStmt{ Fields: &ast.FieldList{}, From: delete.TableRefs, Where: delete.Where, OrderBy: delete.Order, Limit: delete.Limit, } p, err := b.buildResultSetNode(sel.From.TableRefs) if err != nil { return nil, errors.Trace(err) } oldSchema := p.Schema() oldLen := oldSchema.Len() if sel.Where != nil { p, err = b.buildSelection(p, sel.Where, nil) if err != nil { return nil, errors.Trace(err) } } if sel.OrderBy != nil { p, err = b.buildSort(p, sel.OrderBy.Items, nil) if err != nil { return nil, errors.Trace(err) } } if sel.Limit != nil { p, err = b.buildLimit(p, sel.Limit) if err != nil { return nil, errors.Trace(err) } } // Add a projection for the following case, otherwise the final schema will be the schema of the join. // delete from t where a in (select ...) or b in (select ...) 
if !delete.IsMultiTable && oldLen != p.Schema().Len() { proj := LogicalProjection{Exprs: expression.Column2Exprs(p.Schema().Columns[:oldLen])}.Init(b.ctx) proj.SetChildren(p) proj.SetSchema(oldSchema.Clone()) p = proj } var tables []*ast.TableName if delete.Tables != nil { tables = delete.Tables.Tables } del := Delete{ Tables: tables, IsMultiTable: delete.IsMultiTable, }.Init(b.ctx) del.SelectPlan, err = DoOptimize(b.optFlag, p) if err != nil { return nil, errors.Trace(err) } del.SetSchema(expression.NewSchema()) var tableList []*ast.TableName tableList = extractTableList(delete.TableRefs.TableRefs, tableList, true) // Collect visitInfo. if delete.Tables != nil { // Delete a, b from a, b, c, d... add a and b. for _, tn := range delete.Tables.Tables { foundMatch := false for _, v := range tableList { dbName := v.Schema.L if dbName == "" { dbName = b.ctx.GetSessionVars().CurrentDB } if (tn.Schema.L == "" || tn.Schema.L == dbName) && tn.Name.L == v.Name.L { tn.Schema.L = dbName tn.DBInfo = v.DBInfo tn.TableInfo = v.TableInfo foundMatch = true break } } if !foundMatch { var asNameList []string asNameList = extractTableSourceAsNames(delete.TableRefs.TableRefs, asNameList, false) for _, asName := range asNameList { tblName := tn.Name.L if tn.Schema.L != "" { tblName = tn.Schema.L + "." + tblName } if asName == tblName { // check sql like: `delete a from (select * from t) as a, t` return nil, ErrNonUpdatableTable.GenWithStackByArgs(tn.Name.O, "DELETE") } } // check sql like: `delete b from (select * from t) as a, t` return nil, ErrUnknownTable.GenWithStackByArgs(tn.Name.O, "MULTI DELETE") } if tn.TableInfo.IsView() { return nil, errors.Errorf("delete view %s is not supported now.", tn.Name.O) } b.visitInfo = appendVisitInfo(b.visitInfo, mysql.DeletePriv, tn.Schema.L, tn.TableInfo.Name.L, "", nil) } } else { // Delete from a, b, c, d. 
for _, v := range tableList { if v.TableInfo.IsView() { return nil, errors.Errorf("delete view %s is not supported now.", v.Name.O) } dbName := v.Schema.L if dbName == "" { dbName = b.ctx.GetSessionVars().CurrentDB } b.visitInfo = appendVisitInfo(b.visitInfo, mysql.DeletePriv, dbName, v.Name.L, "", nil) } } return del, nil } // buildProjectionForWindow builds the projection for expressions in the window specification that is not an column, // so after the projection, window functions only needs to deal with columns. func (b *PlanBuilder) buildProjectionForWindow(p LogicalPlan, expr *ast.WindowFuncExpr, aggMap map[*ast.AggregateFuncExpr]int) (LogicalPlan, []property.Item, []property.Item, []expression.Expression, error) { b.optFlag |= flagEliminateProjection var items []*ast.ByItem if expr.Spec.Name.L != "" { ref, ok := b.windowSpecs[expr.Spec.Name.L] if !ok { return nil, nil, nil, nil, ErrWindowNoSuchWindow.GenWithStackByArgs(expr.Spec.Name.O) } expr.Spec = ref } else { expr.Spec.Name = model.NewCIStr("") } spec := expr.Spec if spec.Ref.L != "" { ref, ok := b.windowSpecs[spec.Ref.L] if !ok { return nil, nil, nil, nil, ErrWindowNoSuchWindow.GenWithStackByArgs(spec.Ref.O) } err := mergeWindowSpec(&spec, &ref) if err != nil { return nil, nil, nil, nil, err } } lenPartition := 0 if spec.PartitionBy != nil { items = append(items, spec.PartitionBy.Items...) lenPartition = len(spec.PartitionBy.Items) } if spec.OrderBy != nil { items = append(items, spec.OrderBy.Items...) } projLen := len(p.Schema().Columns) + len(items) + len(expr.Args) proj := LogicalProjection{Exprs: make([]expression.Expression, 0, projLen)}.Init(b.ctx) schema := expression.NewSchema(make([]*expression.Column, 0, projLen)...) 
for _, col := range p.Schema().Columns { proj.Exprs = append(proj.Exprs, col) schema.Append(col) } transformer := &itemTransformer{} propertyItems := make([]property.Item, 0, len(items)) for _, item := range items { newExpr, _ := item.Expr.Accept(transformer) item.Expr = newExpr.(ast.ExprNode) it, np, err := b.rewrite(item.Expr, p, aggMap, true) if err != nil { return nil, nil, nil, nil, err } p = np if col, ok := it.(*expression.Column); ok { propertyItems = append(propertyItems, property.Item{Col: col, Desc: item.Desc}) continue } proj.Exprs = append(proj.Exprs, it) col := &expression.Column{ ColName: model.NewCIStr(fmt.Sprintf("%d_proj_window_%d", p.ID(), schema.Len())), UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), RetType: it.GetType(), } schema.Append(col) propertyItems = append(propertyItems, property.Item{Col: col, Desc: item.Desc}) } newArgList := make([]expression.Expression, 0, len(expr.Args)) for _, arg := range expr.Args { newArg, np, err := b.rewrite(arg, p, aggMap, true) if err != nil { return nil, nil, nil, nil, err } p = np switch newArg.(type) { case *expression.Column, *expression.Constant: newArgList = append(newArgList, newArg) continue } proj.Exprs = append(proj.Exprs, newArg) col := &expression.Column{ ColName: model.NewCIStr(fmt.Sprintf("%d_proj_window_%d", p.ID(), schema.Len())), UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), RetType: newArg.GetType(), } schema.Append(col) newArgList = append(newArgList, col) } proj.SetSchema(schema) proj.SetChildren(p) return proj, propertyItems[:lenPartition], propertyItems[lenPartition:], newArgList, nil } // buildWindowFunctionFrameBound builds the bounds of window function frames. // For type `Rows`, the bound expr must be an unsigned integer. // For type `Range`, the bound expr must be temporal or numeric types. 
func (b *PlanBuilder) buildWindowFunctionFrameBound(spec *ast.WindowSpec, orderByItems []property.Item, boundClause *ast.FrameBound) (*FrameBound, error) { frameType := spec.Frame.Type bound := &FrameBound{Type: boundClause.Type, UnBounded: boundClause.UnBounded} if bound.UnBounded { return bound, nil } if frameType == ast.Rows { if bound.Type == ast.CurrentRow { return bound, nil } // Rows type does not support interval range. if boundClause.Unit != nil { return nil, ErrWindowRowsIntervalUse.GenWithStackByArgs(spec.Name) } numRows, isNull, isExpectedType := getUintFromNode(b.ctx, boundClause.Expr) if isNull || !isExpectedType { return nil, ErrWindowFrameIllegal.GenWithStackByArgs(spec.Name) } bound.Num = numRows return bound, nil } bound.CalcFuncs = make([]expression.Expression, len(orderByItems)) bound.CmpFuncs = make([]expression.CompareFunc, len(orderByItems)) if bound.Type == ast.CurrentRow { for i, item := range orderByItems { col := item.Col bound.CalcFuncs[i] = col bound.CmpFuncs[i] = expression.GetCmpFunction(col, col) } return bound, nil } if len(orderByItems) != 1 { return nil, ErrWindowRangeFrameOrderType.GenWithStackByArgs(spec.Name) } col := orderByItems[0].Col isNumeric, isTemporal := types.IsTypeNumeric(col.RetType.Tp), types.IsTypeTemporal(col.RetType.Tp) if !isNumeric && !isTemporal { return nil, ErrWindowRangeFrameOrderType.GenWithStackByArgs(spec.Name) } // Interval bounds only support order by temporal types. if boundClause.Unit != nil && isNumeric { return nil, ErrWindowRangeFrameNumericType.GenWithStackByArgs(spec.Name) } // Non-interval bound only support order by numeric types. if boundClause.Unit == nil && !isNumeric { return nil, ErrWindowRangeFrameTemporalType.GenWithStackByArgs(spec.Name) } // TODO: We also need to raise error for non-deterministic expressions, like rand(). 
val, err := evalAstExpr(b.ctx, boundClause.Expr) if err != nil { return nil, ErrWindowRangeBoundNotConstant.GenWithStackByArgs(spec.Name) } expr := expression.Constant{Value: val, RetType: boundClause.Expr.GetType()} // Do not raise warnings for truncate. oriIgnoreTruncate := b.ctx.GetSessionVars().StmtCtx.IgnoreTruncate b.ctx.GetSessionVars().StmtCtx.IgnoreTruncate = true uVal, isNull, err := expr.EvalInt(b.ctx, chunk.Row{}) b.ctx.GetSessionVars().StmtCtx.IgnoreTruncate = oriIgnoreTruncate if uVal < 0 || isNull || err != nil { return nil, ErrWindowFrameIllegal.GenWithStackByArgs(spec.Name) } desc := orderByItems[0].Desc if boundClause.Unit != nil { // It can be guaranteed by the parser. unitVal := boundClause.Unit.(*driver.ValueExpr) unit := expression.Constant{Value: unitVal.Datum, RetType: unitVal.GetType()} // When the order is asc: // `+` for following, and `-` for the preceding // When the order is desc, `+` becomes `-` and vice-versa. funcName := ast.DateAdd if (!desc && bound.Type == ast.Preceding) || (desc && bound.Type == ast.Following) { funcName = ast.DateSub } bound.CalcFuncs[0], err = expression.NewFunctionBase(b.ctx, funcName, col.RetType, col, &expr, &unit) if err != nil { return nil, err } bound.CmpFuncs[0] = expression.GetCmpFunction(orderByItems[0].Col, bound.CalcFuncs[0]) return bound, nil } // When the order is asc: // `+` for following, and `-` for the preceding // When the order is desc, `+` becomes `-` and vice-versa. funcName := ast.Plus if (!desc && bound.Type == ast.Preceding) || (desc && bound.Type == ast.Following) { funcName = ast.Minus } bound.CalcFuncs[0], err = expression.NewFunctionBase(b.ctx, funcName, col.RetType, col, &expr) if err != nil { return nil, err } bound.CmpFuncs[0] = expression.GetCmpFunction(orderByItems[0].Col, bound.CalcFuncs[0]) return bound, nil } // buildWindowFunctionFrame builds the window function frames. 
// See https://dev.mysql.com/doc/refman/8.0/en/window-functions-frames.html func (b *PlanBuilder) buildWindowFunctionFrame(spec *ast.WindowSpec, orderByItems []property.Item) (*WindowFrame, error) { frameClause := spec.Frame if frameClause == nil { return nil, nil } if frameClause.Type == ast.Groups { return nil, ErrNotSupportedYet.GenWithStackByArgs("GROUPS") } frame := &WindowFrame{Type: frameClause.Type} start := frameClause.Extent.Start if start.Type == ast.Following && start.UnBounded { return nil, ErrWindowFrameStartIllegal.GenWithStackByArgs(spec.Name) } var err error frame.Start, err = b.buildWindowFunctionFrameBound(spec, orderByItems, &start) if err != nil { return nil, err } end := frameClause.Extent.End if end.Type == ast.Preceding && end.UnBounded { return nil, ErrWindowFrameEndIllegal.GenWithStackByArgs(spec.Name) } frame.End, err = b.buildWindowFunctionFrameBound(spec, orderByItems, &end) return frame, err } func (b *PlanBuilder) buildWindowFunction(p LogicalPlan, expr *ast.WindowFuncExpr, aggMap map[*ast.AggregateFuncExpr]int) (*LogicalWindow, error) { p, partitionBy, orderBy, args, err := b.buildProjectionForWindow(p, expr, aggMap) if err != nil { return nil, err } needFrame := aggregation.NeedFrame(expr.F) // According to MySQL, In the absence of a frame clause, the default frame depends on whether an ORDER BY clause is present: // (1) With order by, the default frame is equivalent to "RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"; // (2) Without order by, the default frame is equivalent to "RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", // which is the same as an empty frame. if needFrame && expr.Spec.Frame == nil && len(orderBy) > 0 { expr.Spec.Frame = &ast.FrameClause{ Type: ast.Ranges, Extent: ast.FrameExtent{ Start: ast.FrameBound{Type: ast.Preceding, UnBounded: true}, End: ast.FrameBound{Type: ast.CurrentRow}, }, } } // For functions that operate on the entire partition, the frame clause will be ignored. 
if !needFrame && expr.Spec.Frame != nil { b.ctx.GetSessionVars().StmtCtx.AppendNote(ErrWindowFunctionIgnoresFrame.GenWithStackByArgs(expr.F, expr.Spec.Name.O)) expr.Spec.Frame = nil } frame, err := b.buildWindowFunctionFrame(&expr.Spec, orderBy) if err != nil { return nil, err } desc := aggregation.NewWindowFuncDesc(b.ctx, expr.F, args) if desc == nil { return nil, ErrWrongArguments.GenWithStackByArgs(expr.F) } // TODO: Check if the function is aggregation function after we support more functions. desc.WrapCastForAggArgs(b.ctx) window := LogicalWindow{ WindowFuncDesc: desc, PartitionBy: partitionBy, OrderBy: orderBy, Frame: frame, }.Init(b.ctx) schema := p.Schema().Clone() schema.Append(&expression.Column{ ColName: model.NewCIStr(fmt.Sprintf("%d_window_%d", window.id, p.Schema().Len())), UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), IsReferenced: true, RetType: desc.RetTp, }) window.SetChildren(p) window.SetSchema(schema) return window, nil } // resolveWindowSpec resolve window specifications for sql like `select ... from t window w1 as (w2), w2 as (partition by a)`. // We need to resolve the referenced window to get the definition of current window spec. 
func resolveWindowSpec(spec *ast.WindowSpec, specs map[string]ast.WindowSpec, inStack map[string]bool) error { if inStack[spec.Name.L] { return errors.Trace(ErrWindowCircularityInWindowGraph) } if spec.Ref.L == "" { return nil } ref, ok := specs[spec.Ref.L] if !ok { return ErrWindowNoSuchWindow.GenWithStackByArgs(spec.Ref.O) } inStack[spec.Name.L] = true err := resolveWindowSpec(&ref, specs, inStack) if err != nil { return err } inStack[spec.Name.L] = false return mergeWindowSpec(spec, &ref) } func mergeWindowSpec(spec, ref *ast.WindowSpec) error { if ref.Frame != nil { return ErrWindowNoInherentFrame.GenWithStackByArgs(ref.Name.O) } if ref.OrderBy != nil { if spec.OrderBy != nil { return ErrWindowNoRedefineOrderBy.GenWithStackByArgs(spec.Name.O, ref.Name.O) } spec.OrderBy = ref.OrderBy } if spec.PartitionBy != nil { return errors.Trace(ErrWindowNoChildPartitioning) } spec.PartitionBy = ref.PartitionBy spec.Ref = model.NewCIStr("") return nil } func buildWindowSpecs(specs []ast.WindowSpec) (map[string]ast.WindowSpec, error) { specsMap := make(map[string]ast.WindowSpec, len(specs)) for _, spec := range specs { if _, ok := specsMap[spec.Name.L]; ok { return nil, ErrWindowDuplicateName.GenWithStackByArgs(spec.Name.O) } specsMap[spec.Name.L] = spec } inStack := make(map[string]bool, len(specs)) for _, spec := range specs { err := resolveWindowSpec(&spec, specsMap, inStack) if err != nil { return nil, err } } return specsMap, nil } // extractTableList extracts all the TableNames from node. // If asName is true, extract AsName prior to OrigName. // Privilege check should use OrigName, while expression may use AsName. 
func extractTableList(node ast.ResultSetNode, input []*ast.TableName, asName bool) []*ast.TableName { switch x := node.(type) { case *ast.Join: input = extractTableList(x.Left, input, asName) input = extractTableList(x.Right, input, asName) case *ast.TableSource: if s, ok := x.Source.(*ast.TableName); ok { if x.AsName.L != "" && asName { newTableName := *s newTableName.Name = x.AsName input = append(input, &newTableName) } else { input = append(input, s) } } } return input } // extractTableSourceAsNames extracts TableSource.AsNames from node. // if onlySelectStmt is set to be true, only extracts AsNames when TableSource.Source.(type) == *ast.SelectStmt func extractTableSourceAsNames(node ast.ResultSetNode, input []string, onlySelectStmt bool) []string { switch x := node.(type) { case *ast.Join: input = extractTableSourceAsNames(x.Left, input, onlySelectStmt) input = extractTableSourceAsNames(x.Right, input, onlySelectStmt) case *ast.TableSource: if _, ok := x.Source.(*ast.SelectStmt); !ok && onlySelectStmt { break } if s, ok := x.Source.(*ast.TableName); ok { if x.AsName.L == "" { input = append(input, s.Name.L) break } } input = append(input, x.AsName.L) } return input } func appendVisitInfo(vi []visitInfo, priv mysql.PrivilegeType, db, tbl, col string, err error) []visitInfo { return append(vi, visitInfo{ privilege: priv, db: db, table: tbl, column: col, err: err, }) } func getInnerFromParenthesesAndUnaryPlus(expr ast.ExprNode) ast.ExprNode { if pexpr, ok := expr.(*ast.ParenthesesExpr); ok { return getInnerFromParenthesesAndUnaryPlus(pexpr.Expr) } if uexpr, ok := expr.(*ast.UnaryOperationExpr); ok && uexpr.Op == opcode.Plus { return getInnerFromParenthesesAndUnaryPlus(uexpr.V) } return expr }