From ee5353e9332dcc8a45583de36c0ccddc407f4d0c Mon Sep 17 00:00:00 2001 From: qiuyesuifeng Date: Tue, 20 Oct 2015 16:35:34 +0800 Subject: [PATCH 1/2] rset: update FromIdentVisitor exported. --- rset/rsets/fields.go | 2 +- rset/rsets/groupby.go | 6 ++-- rset/rsets/having.go | 4 +-- rset/rsets/helper.go | 73 ++++++++++++++++++++++++------------------- rset/rsets/join.go | 2 +- rset/rsets/orderby.go | 6 ++-- rset/rsets/where.go | 2 +- 7 files changed, 51 insertions(+), 44 deletions(-) diff --git a/rset/rsets/fields.go b/rset/rsets/fields.go index ffe6176409..51c200afcd 100644 --- a/rset/rsets/fields.go +++ b/rset/rsets/fields.go @@ -40,7 +40,7 @@ type SelectFieldsRset struct { } func updateSelectFieldsRefer(selectList *plans.SelectList) error { - visitor := newFromIdentVisitor(selectList.FromFields, fieldListClause) + visitor := NewFromIdentVisitor(selectList.FromFields, FieldListClause) // we only fix un-hidden fields here, for hidden fields, it should be // handled in their own clause, in order by or having. diff --git a/rset/rsets/groupby.go b/rset/rsets/groupby.go index 9ec51157dc..b3e84a2a9a 100644 --- a/rset/rsets/groupby.go +++ b/rset/rsets/groupby.go @@ -59,7 +59,7 @@ func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, if v.rootIdent == i { // The group by is an identifier, we must check it first. - index, err = checkIdentAmbiguous(i, v.selectList, groupByClause) + index, err = checkIdentAmbiguous(i, v.selectList, GroupByClause) if err != nil { return nil, errors.Trace(err) } @@ -75,7 +75,7 @@ func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, if v.rootIdent != i { // This identifier is the part of the group by, check ambiguous here. - index, err = checkIdentAmbiguous(i, v.selectList, groupByClause) + index, err = checkIdentAmbiguous(i, v.selectList, GroupByClause) if err != nil { return nil, errors.Trace(err) } @@ -126,7 +126,7 @@ func (r *GroupByRset) Plan(ctx context.Context) (plan.Plan, error) { visitor.selectList = r.SelectList for i, e := range r.By { - pos, err := castPosition(e, r.SelectList, groupByClause) + pos, err := castPosition(e, r.SelectList, GroupByClause) if err != nil { return nil, errors.Trace(err) } diff --git a/rset/rsets/having.go b/rset/rsets/having.go index 7ca80d4f73..ab5d86b4d6 100644 --- a/rset/rsets/having.go +++ b/rset/rsets/having.go @@ -74,7 +74,7 @@ func (v *havingVisitor) visitIdentInAggregate(i *expression.Ident) (expression.E } // check in select list. - index, err := checkIdentAmbiguous(i, v.selectList, havingClause) + index, err := checkIdentAmbiguous(i, v.selectList, HavingClause) if err != nil { return i, errors.Trace(err) } @@ -129,7 +129,7 @@ func (v *havingVisitor) checkIdentInGroupBy(i *expression.Ident) (*expression.Id } func (v *havingVisitor) checkIdentInSelectList(i *expression.Ident) (*expression.Ident, bool, error) { - index, err := checkIdentAmbiguous(i, v.selectList, havingClause) + index, err := checkIdentAmbiguous(i, v.selectList, HavingClause) if err != nil { return i, false, errors.Trace(err) } diff --git a/rset/rsets/helper.go b/rset/rsets/helper.go index fa47f1d79d..7c47982a16 100644 --- a/rset/rsets/helper.go +++ b/rset/rsets/helper.go @@ -46,39 +46,44 @@ func castIdent(e expression.Expression) *expression.Ident { return i } +// ClauseType is clause type. // TODO: export clause type and move to plan? -type clauseType int +type ClauseType int +// Clause Types. const ( - noneClause clauseType = iota - onClause - whereClause - groupByClause - fieldListClause - havingClause - orderByClause + NoneClause ClauseType = iota + OnClause + WhereClause + GroupByClause + FieldListClause + HavingClause + OrderByClause + UpdateClause ) -func (clause clauseType) String() string { +func (clause ClauseType) String() string { switch clause { - case onClause: + case OnClause: return "on clause" - case fieldListClause: - return "field list" - case whereClause: + case WhereClause: return "where clause" - case groupByClause: + case GroupByClause: return "group statement" - case orderByClause: - return "order clause" - case havingClause: + case FieldListClause: + return "field list" + case HavingClause: return "having clause" + case OrderByClause: + return "order clause" + case UpdateClause: + return "update clause" } return "none" } // castPosition returns an group/order by Position expression if e is a number. -func castPosition(e expression.Expression, selectList *plans.SelectList, clause clauseType) (*expression.Position, error) { +func castPosition(e expression.Expression, selectList *plans.SelectList, clause ClauseType) (*expression.Position, error) { v, ok := e.(expression.Value) if !ok { return nil, nil @@ -98,7 +103,7 @@ func castPosition(e expression.Expression, selectList *plans.SelectList, clause return nil, errors.Errorf("Unknown column '%d' in '%s'", position, clause) } - if clause == groupByClause { + if clause == GroupByClause { index := position - 1 if _, ok := selectList.AggFields[index]; ok { return nil, errors.Errorf("Can't group on '%s'", selectList.Fields[index]) @@ -109,7 +114,7 @@ func castPosition(e expression.Expression, selectList *plans.SelectList, clause return &expression.Position{N: position}, nil } -func checkIdentAmbiguous(i *expression.Ident, selectList *plans.SelectList, clause clauseType) (int, error) { +func checkIdentAmbiguous(i *expression.Ident, selectList *plans.SelectList, clause ClauseType) (int, error) { index, err := selectList.CheckAmbiguous(i) if err != nil { return -1, errors.Errorf("Column '%s' in %s is ambiguous", i, clause) @@ -120,38 +125,40 @@ func checkIdentAmbiguous(i *expression.Ident, selectList *plans.SelectList, clau return index, nil } -// fromIdentVisitor can only handle identifier which reference FROM table or outer query. +// FromIdentVisitor can only handle identifier which reference FROM table or outer query. // like in common select list, where or join on condition. -type fromIdentVisitor struct { +type FromIdentVisitor struct { expression.BaseVisitor - fromFields []*field.ResultField - clause clauseType + FromFields []*field.ResultField + Clause ClauseType } -func (v *fromIdentVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) { - idx := field.GetResultFieldIndex(i.L, v.fromFields) +// VisitIdent implements Visitor interface. +func (v *FromIdentVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) { + idx := field.GetResultFieldIndex(i.L, v.FromFields) if len(idx) == 1 { i.ReferScope = expression.IdentReferFromTable i.ReferIndex = idx[0] return i, nil } else if len(idx) > 1 { - return nil, errors.Errorf("Column '%s' in %s is ambiguous", i, v.clause) + return nil, errors.Errorf("Column '%s' in %s is ambiguous", i, v.Clause) } - if v.clause == onClause { + if v.Clause == OnClause { // on clause can't check outer query. - return nil, errors.Errorf("Unknown column '%s' in '%s'", i, v.clause) + return nil, errors.Errorf("Unknown column '%s' in '%s'", i, v.Clause) } // TODO: check in outer query return i, nil } -func newFromIdentVisitor(fromFields []*field.ResultField, clause clauseType) *fromIdentVisitor { - visitor := &fromIdentVisitor{} +// NewFromIdentVisitor creates a new FromIdentVisitor. +func NewFromIdentVisitor(fromFields []*field.ResultField, clause ClauseType) *FromIdentVisitor { + visitor := &FromIdentVisitor{} visitor.BaseVisitor.V = visitor - visitor.fromFields = fromFields - visitor.clause = clause + visitor.FromFields = fromFields + visitor.Clause = clause return visitor } diff --git a/rset/rsets/join.go b/rset/rsets/join.go index 97e7c8e583..8d343cd51b 100644 --- a/rset/rsets/join.go +++ b/rset/rsets/join.go @@ -142,7 +142,7 @@ func (r *JoinRset) buildJoinPlan(ctx context.Context, p *plans.JoinPlan, s *Join } if p.On != nil { - visitor := newFromIdentVisitor(p.Fields, onClause) + visitor := NewFromIdentVisitor(p.Fields, OnClause) e, err := p.On.Accept(visitor) if err != nil { diff --git a/rset/rsets/orderby.go b/rset/rsets/orderby.go index 37ad99d0e9..b6fd4b3c22 100644 --- a/rset/rsets/orderby.go +++ b/rset/rsets/orderby.go @@ -111,7 +111,7 @@ func (v *orderByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, if v.rootIdent == i { // The order by is an identifier, we must check it first. - index, err = checkIdentAmbiguous(i, v.selectList, orderByClause) + index, err = checkIdentAmbiguous(i, v.selectList, OrderByClause) if err != nil { return nil, errors.Trace(err) } @@ -135,7 +135,7 @@ func (v *orderByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, if v.rootIdent != i { // This identifier is the part of the order by, check ambiguous here. - index, err = checkIdentAmbiguous(i, v.selectList, orderByClause) + index, err = checkIdentAmbiguous(i, v.selectList, OrderByClause) if err != nil { return nil, errors.Trace(err) } @@ -191,7 +191,7 @@ func (r *OrderByRset) Plan(ctx context.Context) (plan.Plan, error) { for i := range r.By { e := r.By[i].Expr - pos, err := castPosition(e, r.SelectList, orderByClause) + pos, err := castPosition(e, r.SelectList, OrderByClause) if err != nil { return nil, errors.Trace(err) } diff --git a/rset/rsets/where.go b/rset/rsets/where.go index 3a583fe010..4b6ed062be 100644 --- a/rset/rsets/where.go +++ b/rset/rsets/where.go @@ -180,7 +180,7 @@ func (r *WhereRset) planStatic(ctx context.Context, e expression.Expression) (pl } func (r *WhereRset) updateWhereFieldsRefer() error { - visitor := newFromIdentVisitor(r.Src.GetFields(), whereClause) + visitor := NewFromIdentVisitor(r.Src.GetFields(), WhereClause) e, err := r.Expr.Accept(visitor) if err != nil { From bd5d23d5d58b3c4cc4596f16f73c802854ac1bb0 Mon Sep 17 00:00:00 2001 From: qiuyesuifeng Date: Tue, 20 Oct 2015 16:36:00 +0800 Subject: [PATCH 2/2] stmt: update stmt use FromIdentVisitor. --- stmt/stmts/update.go | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/stmt/stmts/update.go b/stmt/stmts/update.go index 4635ffc71b..abfd0544e4 100644 --- a/stmt/stmts/update.go +++ b/stmt/stmts/update.go @@ -28,7 +28,6 @@ import ( "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/plan" - "github.com/pingcap/tidb/plan/plans" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/rset/rsets" "github.com/pingcap/tidb/sessionctx/variable" @@ -101,7 +100,6 @@ func getUpdateColumns(assignList []expression.Assignment, fields []*field.Result m := make(map[int]expression.Assignment, len(assignList)) for _, v := range assignList { - name := v.ColName if len(v.TableName) > 0 { name = fmt.Sprintf("%s.%s", v.TableName, v.ColName) @@ -234,6 +232,17 @@ func (s *UpdateStmt) plan(ctx context.Context) (plan.Plan, error) { return nil, errors.Trace(err) } } + + visitor := rsets.NewFromIdentVisitor(r.GetFields(), rsets.UpdateClause) + for i := range s.List { + e, err := s.List[i].Expr.Accept(visitor) + if err != nil { + return nil, errors.Trace(err) + } + + s.List[i].Expr = e + } + return r, nil } @@ -274,9 +283,9 @@ func (s *UpdateStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { for _, row := range records { rowData := row.Data - // Set EvalIdentFunc - m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { - return plans.GetIdentValue(name, p.GetFields(), rowData) + // Set ExprEvalIdentReferFunc + m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + return rowData[index], nil } // Update rows