Merge pull request #422 from pingcap/qiuyesuifeng/update-stmt-visitor
Update stmt use FromIdentVisitor for column eval.
This commit is contained in:
@ -40,7 +40,7 @@ type SelectFieldsRset struct {
|
||||
}
|
||||
|
||||
func updateSelectFieldsRefer(selectList *plans.SelectList) error {
|
||||
visitor := newFromIdentVisitor(selectList.FromFields, fieldListClause)
|
||||
visitor := NewFromIdentVisitor(selectList.FromFields, FieldListClause)
|
||||
|
||||
// we only fix un-hidden fields here, for hidden fields, it should be
|
||||
// handled in their own clause, in order by or having.
|
||||
|
||||
@ -59,7 +59,7 @@ func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression,
|
||||
|
||||
if v.rootIdent == i {
|
||||
// The group by is an identifier, we must check it first.
|
||||
index, err = checkIdentAmbiguous(i, v.selectList, groupByClause)
|
||||
index, err = checkIdentAmbiguous(i, v.selectList, GroupByClause)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
@ -75,7 +75,7 @@ func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression,
|
||||
|
||||
if v.rootIdent != i {
|
||||
// This identifier is the part of the group by, check ambiguous here.
|
||||
index, err = checkIdentAmbiguous(i, v.selectList, groupByClause)
|
||||
index, err = checkIdentAmbiguous(i, v.selectList, GroupByClause)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
@ -126,7 +126,7 @@ func (r *GroupByRset) Plan(ctx context.Context) (plan.Plan, error) {
|
||||
visitor.selectList = r.SelectList
|
||||
|
||||
for i, e := range r.By {
|
||||
pos, err := castPosition(e, r.SelectList, groupByClause)
|
||||
pos, err := castPosition(e, r.SelectList, GroupByClause)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
@ -74,7 +74,7 @@ func (v *havingVisitor) visitIdentInAggregate(i *expression.Ident) (expression.E
|
||||
}
|
||||
|
||||
// check in select list.
|
||||
index, err := checkIdentAmbiguous(i, v.selectList, havingClause)
|
||||
index, err := checkIdentAmbiguous(i, v.selectList, HavingClause)
|
||||
if err != nil {
|
||||
return i, errors.Trace(err)
|
||||
}
|
||||
@ -129,7 +129,7 @@ func (v *havingVisitor) checkIdentInGroupBy(i *expression.Ident) (*expression.Id
|
||||
}
|
||||
|
||||
func (v *havingVisitor) checkIdentInSelectList(i *expression.Ident) (*expression.Ident, bool, error) {
|
||||
index, err := checkIdentAmbiguous(i, v.selectList, havingClause)
|
||||
index, err := checkIdentAmbiguous(i, v.selectList, HavingClause)
|
||||
if err != nil {
|
||||
return i, false, errors.Trace(err)
|
||||
}
|
||||
|
||||
@ -46,39 +46,44 @@ func castIdent(e expression.Expression) *expression.Ident {
|
||||
return i
|
||||
}
|
||||
|
||||
// ClauseType is clause type.
|
||||
// TODO: export clause type and move to plan?
|
||||
type clauseType int
|
||||
type ClauseType int
|
||||
|
||||
// Clause Types.
|
||||
const (
|
||||
noneClause clauseType = iota
|
||||
onClause
|
||||
whereClause
|
||||
groupByClause
|
||||
fieldListClause
|
||||
havingClause
|
||||
orderByClause
|
||||
NoneClause ClauseType = iota
|
||||
OnClause
|
||||
WhereClause
|
||||
GroupByClause
|
||||
FieldListClause
|
||||
HavingClause
|
||||
OrderByClause
|
||||
UpdateClause
|
||||
)
|
||||
|
||||
func (clause clauseType) String() string {
|
||||
func (clause ClauseType) String() string {
|
||||
switch clause {
|
||||
case onClause:
|
||||
case OnClause:
|
||||
return "on clause"
|
||||
case fieldListClause:
|
||||
return "field list"
|
||||
case whereClause:
|
||||
case WhereClause:
|
||||
return "where clause"
|
||||
case groupByClause:
|
||||
case GroupByClause:
|
||||
return "group statement"
|
||||
case orderByClause:
|
||||
return "order clause"
|
||||
case havingClause:
|
||||
case FieldListClause:
|
||||
return "field list"
|
||||
case HavingClause:
|
||||
return "having clause"
|
||||
case OrderByClause:
|
||||
return "order clause"
|
||||
case UpdateClause:
|
||||
return "update clause"
|
||||
}
|
||||
return "none"
|
||||
}
|
||||
|
||||
// castPosition returns an group/order by Position expression if e is a number.
|
||||
func castPosition(e expression.Expression, selectList *plans.SelectList, clause clauseType) (*expression.Position, error) {
|
||||
func castPosition(e expression.Expression, selectList *plans.SelectList, clause ClauseType) (*expression.Position, error) {
|
||||
v, ok := e.(expression.Value)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
@ -98,7 +103,7 @@ func castPosition(e expression.Expression, selectList *plans.SelectList, clause
|
||||
return nil, errors.Errorf("Unknown column '%d' in '%s'", position, clause)
|
||||
}
|
||||
|
||||
if clause == groupByClause {
|
||||
if clause == GroupByClause {
|
||||
index := position - 1
|
||||
if _, ok := selectList.AggFields[index]; ok {
|
||||
return nil, errors.Errorf("Can't group on '%s'", selectList.Fields[index])
|
||||
@ -109,7 +114,7 @@ func castPosition(e expression.Expression, selectList *plans.SelectList, clause
|
||||
return &expression.Position{N: position}, nil
|
||||
}
|
||||
|
||||
func checkIdentAmbiguous(i *expression.Ident, selectList *plans.SelectList, clause clauseType) (int, error) {
|
||||
func checkIdentAmbiguous(i *expression.Ident, selectList *plans.SelectList, clause ClauseType) (int, error) {
|
||||
index, err := selectList.CheckAmbiguous(i)
|
||||
if err != nil {
|
||||
return -1, errors.Errorf("Column '%s' in %s is ambiguous", i, clause)
|
||||
@ -120,38 +125,40 @@ func checkIdentAmbiguous(i *expression.Ident, selectList *plans.SelectList, clau
|
||||
return index, nil
|
||||
}
|
||||
|
||||
// fromIdentVisitor can only handle identifier which reference FROM table or outer query.
|
||||
// FromIdentVisitor can only handle identifier which reference FROM table or outer query.
|
||||
// like in common select list, where or join on condition.
|
||||
type fromIdentVisitor struct {
|
||||
type FromIdentVisitor struct {
|
||||
expression.BaseVisitor
|
||||
fromFields []*field.ResultField
|
||||
clause clauseType
|
||||
FromFields []*field.ResultField
|
||||
Clause ClauseType
|
||||
}
|
||||
|
||||
func (v *fromIdentVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) {
|
||||
idx := field.GetResultFieldIndex(i.L, v.fromFields)
|
||||
// VisitIdent implements Visitor interface.
|
||||
func (v *FromIdentVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) {
|
||||
idx := field.GetResultFieldIndex(i.L, v.FromFields)
|
||||
if len(idx) == 1 {
|
||||
i.ReferScope = expression.IdentReferFromTable
|
||||
i.ReferIndex = idx[0]
|
||||
return i, nil
|
||||
} else if len(idx) > 1 {
|
||||
return nil, errors.Errorf("Column '%s' in %s is ambiguous", i, v.clause)
|
||||
return nil, errors.Errorf("Column '%s' in %s is ambiguous", i, v.Clause)
|
||||
}
|
||||
|
||||
if v.clause == onClause {
|
||||
if v.Clause == OnClause {
|
||||
// on clause can't check outer query.
|
||||
return nil, errors.Errorf("Unknown column '%s' in '%s'", i, v.clause)
|
||||
return nil, errors.Errorf("Unknown column '%s' in '%s'", i, v.Clause)
|
||||
}
|
||||
|
||||
// TODO: check in outer query
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func newFromIdentVisitor(fromFields []*field.ResultField, clause clauseType) *fromIdentVisitor {
|
||||
visitor := &fromIdentVisitor{}
|
||||
// NewFromIdentVisitor creates a new FromIdentVisitor.
|
||||
func NewFromIdentVisitor(fromFields []*field.ResultField, clause ClauseType) *FromIdentVisitor {
|
||||
visitor := &FromIdentVisitor{}
|
||||
visitor.BaseVisitor.V = visitor
|
||||
visitor.fromFields = fromFields
|
||||
visitor.clause = clause
|
||||
visitor.FromFields = fromFields
|
||||
visitor.Clause = clause
|
||||
|
||||
return visitor
|
||||
}
|
||||
|
||||
@ -142,7 +142,7 @@ func (r *JoinRset) buildJoinPlan(ctx context.Context, p *plans.JoinPlan, s *Join
|
||||
}
|
||||
|
||||
if p.On != nil {
|
||||
visitor := newFromIdentVisitor(p.Fields, onClause)
|
||||
visitor := NewFromIdentVisitor(p.Fields, OnClause)
|
||||
|
||||
e, err := p.On.Accept(visitor)
|
||||
if err != nil {
|
||||
|
||||
@ -111,7 +111,7 @@ func (v *orderByVisitor) VisitIdent(i *expression.Ident) (expression.Expression,
|
||||
|
||||
if v.rootIdent == i {
|
||||
// The order by is an identifier, we must check it first.
|
||||
index, err = checkIdentAmbiguous(i, v.selectList, orderByClause)
|
||||
index, err = checkIdentAmbiguous(i, v.selectList, OrderByClause)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
@ -135,7 +135,7 @@ func (v *orderByVisitor) VisitIdent(i *expression.Ident) (expression.Expression,
|
||||
|
||||
if v.rootIdent != i {
|
||||
// This identifier is the part of the order by, check ambiguous here.
|
||||
index, err = checkIdentAmbiguous(i, v.selectList, orderByClause)
|
||||
index, err = checkIdentAmbiguous(i, v.selectList, OrderByClause)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
@ -191,7 +191,7 @@ func (r *OrderByRset) Plan(ctx context.Context) (plan.Plan, error) {
|
||||
for i := range r.By {
|
||||
e := r.By[i].Expr
|
||||
|
||||
pos, err := castPosition(e, r.SelectList, orderByClause)
|
||||
pos, err := castPosition(e, r.SelectList, OrderByClause)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
@ -180,7 +180,7 @@ func (r *WhereRset) planStatic(ctx context.Context, e expression.Expression) (pl
|
||||
}
|
||||
|
||||
func (r *WhereRset) updateWhereFieldsRefer() error {
|
||||
visitor := newFromIdentVisitor(r.Src.GetFields(), whereClause)
|
||||
visitor := NewFromIdentVisitor(r.Src.GetFields(), WhereClause)
|
||||
|
||||
e, err := r.Expr.Accept(visitor)
|
||||
if err != nil {
|
||||
|
||||
@ -28,7 +28,6 @@ import (
|
||||
"github.com/pingcap/tidb/field"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/plan"
|
||||
"github.com/pingcap/tidb/plan/plans"
|
||||
"github.com/pingcap/tidb/rset"
|
||||
"github.com/pingcap/tidb/rset/rsets"
|
||||
"github.com/pingcap/tidb/sessionctx/variable"
|
||||
@ -101,7 +100,6 @@ func getUpdateColumns(assignList []expression.Assignment, fields []*field.Result
|
||||
m := make(map[int]expression.Assignment, len(assignList))
|
||||
|
||||
for _, v := range assignList {
|
||||
|
||||
name := v.ColName
|
||||
if len(v.TableName) > 0 {
|
||||
name = fmt.Sprintf("%s.%s", v.TableName, v.ColName)
|
||||
@ -234,6 +232,17 @@ func (s *UpdateStmt) plan(ctx context.Context) (plan.Plan, error) {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
}
|
||||
|
||||
visitor := rsets.NewFromIdentVisitor(r.GetFields(), rsets.UpdateClause)
|
||||
for i := range s.List {
|
||||
e, err := s.List[i].Expr.Accept(visitor)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
s.List[i].Expr = e
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
@ -274,9 +283,9 @@ func (s *UpdateStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) {
|
||||
for _, row := range records {
|
||||
rowData := row.Data
|
||||
|
||||
// Set EvalIdentFunc
|
||||
m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
|
||||
return plans.GetIdentValue(name, p.GetFields(), rowData)
|
||||
// Set ExprEvalIdentReferFunc
|
||||
m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
|
||||
return rowData[index], nil
|
||||
}
|
||||
|
||||
// Update rows
|
||||
|
||||
Reference in New Issue
Block a user