Merge pull request #422 from pingcap/qiuyesuifeng/update-stmt-visitor

Update stmt use  FromIdentVisitor for column eval.
This commit is contained in:
qiuyesuifeng
2015-10-22 14:57:41 +08:00
8 changed files with 65 additions and 49 deletions

View File

@ -40,7 +40,7 @@ type SelectFieldsRset struct {
}
func updateSelectFieldsRefer(selectList *plans.SelectList) error {
visitor := newFromIdentVisitor(selectList.FromFields, fieldListClause)
visitor := NewFromIdentVisitor(selectList.FromFields, FieldListClause)
// we only fix un-hidden fields here, for hidden fields, it should be
// handled in their own clause, in order by or having.

View File

@ -59,7 +59,7 @@ func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression,
if v.rootIdent == i {
// The group by is an identifier, we must check it first.
index, err = checkIdentAmbiguous(i, v.selectList, groupByClause)
index, err = checkIdentAmbiguous(i, v.selectList, GroupByClause)
if err != nil {
return nil, errors.Trace(err)
}
@ -75,7 +75,7 @@ func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression,
if v.rootIdent != i {
// This identifier is the part of the group by, check ambiguous here.
index, err = checkIdentAmbiguous(i, v.selectList, groupByClause)
index, err = checkIdentAmbiguous(i, v.selectList, GroupByClause)
if err != nil {
return nil, errors.Trace(err)
}
@ -126,7 +126,7 @@ func (r *GroupByRset) Plan(ctx context.Context) (plan.Plan, error) {
visitor.selectList = r.SelectList
for i, e := range r.By {
pos, err := castPosition(e, r.SelectList, groupByClause)
pos, err := castPosition(e, r.SelectList, GroupByClause)
if err != nil {
return nil, errors.Trace(err)
}

View File

@ -74,7 +74,7 @@ func (v *havingVisitor) visitIdentInAggregate(i *expression.Ident) (expression.E
}
// check in select list.
index, err := checkIdentAmbiguous(i, v.selectList, havingClause)
index, err := checkIdentAmbiguous(i, v.selectList, HavingClause)
if err != nil {
return i, errors.Trace(err)
}
@ -129,7 +129,7 @@ func (v *havingVisitor) checkIdentInGroupBy(i *expression.Ident) (*expression.Id
}
func (v *havingVisitor) checkIdentInSelectList(i *expression.Ident) (*expression.Ident, bool, error) {
index, err := checkIdentAmbiguous(i, v.selectList, havingClause)
index, err := checkIdentAmbiguous(i, v.selectList, HavingClause)
if err != nil {
return i, false, errors.Trace(err)
}

View File

@ -46,39 +46,44 @@ func castIdent(e expression.Expression) *expression.Ident {
return i
}
// ClauseType is clause type.
// TODO: export clause type and move to plan?
type clauseType int
type ClauseType int
// Clause Types.
const (
noneClause clauseType = iota
onClause
whereClause
groupByClause
fieldListClause
havingClause
orderByClause
NoneClause ClauseType = iota
OnClause
WhereClause
GroupByClause
FieldListClause
HavingClause
OrderByClause
UpdateClause
)
func (clause clauseType) String() string {
func (clause ClauseType) String() string {
switch clause {
case onClause:
case OnClause:
return "on clause"
case fieldListClause:
return "field list"
case whereClause:
case WhereClause:
return "where clause"
case groupByClause:
case GroupByClause:
return "group statement"
case orderByClause:
return "order clause"
case havingClause:
case FieldListClause:
return "field list"
case HavingClause:
return "having clause"
case OrderByClause:
return "order clause"
case UpdateClause:
return "update clause"
}
return "none"
}
// castPosition returns an group/order by Position expression if e is a number.
func castPosition(e expression.Expression, selectList *plans.SelectList, clause clauseType) (*expression.Position, error) {
func castPosition(e expression.Expression, selectList *plans.SelectList, clause ClauseType) (*expression.Position, error) {
v, ok := e.(expression.Value)
if !ok {
return nil, nil
@ -98,7 +103,7 @@ func castPosition(e expression.Expression, selectList *plans.SelectList, clause
return nil, errors.Errorf("Unknown column '%d' in '%s'", position, clause)
}
if clause == groupByClause {
if clause == GroupByClause {
index := position - 1
if _, ok := selectList.AggFields[index]; ok {
return nil, errors.Errorf("Can't group on '%s'", selectList.Fields[index])
@ -109,7 +114,7 @@ func castPosition(e expression.Expression, selectList *plans.SelectList, clause
return &expression.Position{N: position}, nil
}
func checkIdentAmbiguous(i *expression.Ident, selectList *plans.SelectList, clause clauseType) (int, error) {
func checkIdentAmbiguous(i *expression.Ident, selectList *plans.SelectList, clause ClauseType) (int, error) {
index, err := selectList.CheckAmbiguous(i)
if err != nil {
return -1, errors.Errorf("Column '%s' in %s is ambiguous", i, clause)
@ -120,38 +125,40 @@ func checkIdentAmbiguous(i *expression.Ident, selectList *plans.SelectList, clau
return index, nil
}
// fromIdentVisitor can only handle identifier which reference FROM table or outer query.
// FromIdentVisitor can only handle identifier which reference FROM table or outer query.
// like in common select list, where or join on condition.
type fromIdentVisitor struct {
type FromIdentVisitor struct {
expression.BaseVisitor
fromFields []*field.ResultField
clause clauseType
FromFields []*field.ResultField
Clause ClauseType
}
func (v *fromIdentVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) {
idx := field.GetResultFieldIndex(i.L, v.fromFields)
// VisitIdent implements Visitor interface.
func (v *FromIdentVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) {
idx := field.GetResultFieldIndex(i.L, v.FromFields)
if len(idx) == 1 {
i.ReferScope = expression.IdentReferFromTable
i.ReferIndex = idx[0]
return i, nil
} else if len(idx) > 1 {
return nil, errors.Errorf("Column '%s' in %s is ambiguous", i, v.clause)
return nil, errors.Errorf("Column '%s' in %s is ambiguous", i, v.Clause)
}
if v.clause == onClause {
if v.Clause == OnClause {
// on clause can't check outer query.
return nil, errors.Errorf("Unknown column '%s' in '%s'", i, v.clause)
return nil, errors.Errorf("Unknown column '%s' in '%s'", i, v.Clause)
}
// TODO: check in outer query
return i, nil
}
func newFromIdentVisitor(fromFields []*field.ResultField, clause clauseType) *fromIdentVisitor {
visitor := &fromIdentVisitor{}
// NewFromIdentVisitor creates a new FromIdentVisitor.
func NewFromIdentVisitor(fromFields []*field.ResultField, clause ClauseType) *FromIdentVisitor {
visitor := &FromIdentVisitor{}
visitor.BaseVisitor.V = visitor
visitor.fromFields = fromFields
visitor.clause = clause
visitor.FromFields = fromFields
visitor.Clause = clause
return visitor
}

View File

@ -142,7 +142,7 @@ func (r *JoinRset) buildJoinPlan(ctx context.Context, p *plans.JoinPlan, s *Join
}
if p.On != nil {
visitor := newFromIdentVisitor(p.Fields, onClause)
visitor := NewFromIdentVisitor(p.Fields, OnClause)
e, err := p.On.Accept(visitor)
if err != nil {

View File

@ -111,7 +111,7 @@ func (v *orderByVisitor) VisitIdent(i *expression.Ident) (expression.Expression,
if v.rootIdent == i {
// The order by is an identifier, we must check it first.
index, err = checkIdentAmbiguous(i, v.selectList, orderByClause)
index, err = checkIdentAmbiguous(i, v.selectList, OrderByClause)
if err != nil {
return nil, errors.Trace(err)
}
@ -135,7 +135,7 @@ func (v *orderByVisitor) VisitIdent(i *expression.Ident) (expression.Expression,
if v.rootIdent != i {
// This identifier is the part of the order by, check ambiguous here.
index, err = checkIdentAmbiguous(i, v.selectList, orderByClause)
index, err = checkIdentAmbiguous(i, v.selectList, OrderByClause)
if err != nil {
return nil, errors.Trace(err)
}
@ -191,7 +191,7 @@ func (r *OrderByRset) Plan(ctx context.Context) (plan.Plan, error) {
for i := range r.By {
e := r.By[i].Expr
pos, err := castPosition(e, r.SelectList, orderByClause)
pos, err := castPosition(e, r.SelectList, OrderByClause)
if err != nil {
return nil, errors.Trace(err)
}

View File

@ -180,7 +180,7 @@ func (r *WhereRset) planStatic(ctx context.Context, e expression.Expression) (pl
}
func (r *WhereRset) updateWhereFieldsRefer() error {
visitor := newFromIdentVisitor(r.Src.GetFields(), whereClause)
visitor := NewFromIdentVisitor(r.Src.GetFields(), WhereClause)
e, err := r.Expr.Accept(visitor)
if err != nil {

View File

@ -28,7 +28,6 @@ import (
"github.com/pingcap/tidb/field"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/plan"
"github.com/pingcap/tidb/plan/plans"
"github.com/pingcap/tidb/rset"
"github.com/pingcap/tidb/rset/rsets"
"github.com/pingcap/tidb/sessionctx/variable"
@ -101,7 +100,6 @@ func getUpdateColumns(assignList []expression.Assignment, fields []*field.Result
m := make(map[int]expression.Assignment, len(assignList))
for _, v := range assignList {
name := v.ColName
if len(v.TableName) > 0 {
name = fmt.Sprintf("%s.%s", v.TableName, v.ColName)
@ -234,6 +232,17 @@ func (s *UpdateStmt) plan(ctx context.Context) (plan.Plan, error) {
return nil, errors.Trace(err)
}
}
visitor := rsets.NewFromIdentVisitor(r.GetFields(), rsets.UpdateClause)
for i := range s.List {
e, err := s.List[i].Expr.Accept(visitor)
if err != nil {
return nil, errors.Trace(err)
}
s.List[i].Expr = e
}
return r, nil
}
@ -274,9 +283,9 @@ func (s *UpdateStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) {
for _, row := range records {
rowData := row.Data
// Set EvalIdentFunc
m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
return plans.GetIdentValue(name, p.GetFields(), rowData)
// Set ExprEvalIdentReferFunc
m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
return rowData[index], nil
}
// Update rows