From 4aeaeada1dbfccb2feffabfa577df8a99e0dcafb Mon Sep 17 00:00:00 2001 From: qiuyesuifeng Date: Tue, 20 Oct 2015 12:19:06 +0800 Subject: [PATCH] *: clean up --- field/result_field.go | 107 ++++--------------------------------- field/result_field_test.go | 46 ++-------------- plan/plans/from.go | 2 +- plan/plans/plans.go | 4 +- plan/plans/row_stack.go | 4 +- plan/plans/select_list.go | 4 +- rset/rsets/groupby.go | 2 +- rset/rsets/having.go | 5 +- rset/rsets/helper.go | 2 +- rset/rsets/orderby.go | 2 +- stmt/stmts/update.go | 4 +- 11 files changed, 30 insertions(+), 152 deletions(-) diff --git a/field/result_field.go b/field/result_field.go index fd79ab9e77..0fa1541604 100644 --- a/field/result_field.go +++ b/field/result_field.go @@ -26,20 +26,6 @@ import ( "github.com/pingcap/tidb/mysql" ) -const ( - // DefaultFieldFlag is `FieldNameFlag`. - DefaultFieldFlag = FieldNameFlag - // CheckFieldFlag includes `FieldNameFlag` and `OrgFieldNameFlag`. - CheckFieldFlag = FieldNameFlag | OrgFieldNameFlag -) - -const ( - // OrgFieldNameFlag is origin filed name flag. - OrgFieldNameFlag uint32 = 1 << iota - // FieldNameFlag is filed name flag. - FieldNameFlag -) - // ResultField provides meta data of table column. type ResultField struct { column.Col // Col.Name is OrgName. @@ -96,47 +82,17 @@ func ColToResultField(col *column.Col, tableName string) *ResultField { return rf } -// CheckAmbiguousField checks whether name is an ambiguous field in ResultFields. -func CheckAmbiguousField(name string, fields []*ResultField, flag uint32) error { - indices := GetResultFieldIndex(name, fields, flag) - if len(indices) > 1 { - return errors.Errorf("ambiguous field %s", name) - } - - return nil -} - -// CheckFieldName checks whether name is an unknown or ambiguous field in ResultFields. -func CheckFieldName(name string, fields []*ResultField, flag uint32) error { - if !ContainFieldName(name, fields, flag) { - return errors.Errorf("unknown field %s", name) - } - - return CheckAmbiguousField(name, fields, OrgFieldNameFlag) -} - -// CheckAllFieldNames checks whether names has unknown or ambiguous field in ResultFields. -func CheckAllFieldNames(names []string, fields []*ResultField, flag uint32) error { - for _, name := range names { - if err := CheckFieldName(name, fields, flag); err != nil { - return errors.Trace(err) - } - } - - return nil -} - // ContainFieldName checks whether name is in ResultFields. -func ContainFieldName(name string, fields []*ResultField, flag uint32) bool { - indices := GetResultFieldIndex(name, fields, flag) +func ContainFieldName(name string, fields []*ResultField) bool { + indices := GetResultFieldIndex(name, fields) return len(indices) > 0 } // ContainAllFieldNames checks whether names are all in ResultFields. // TODO: add alias table name support -func ContainAllFieldNames(names []string, fields []*ResultField, flag uint32) bool { +func ContainAllFieldNames(names []string, fields []*ResultField) bool { for _, name := range names { - if !ContainFieldName(name, fields, flag) { + if !ContainFieldName(name, fields) { return false } } @@ -179,55 +135,14 @@ func JoinQualifiedName(db string, table string, field string) string { } // GetResultFieldIndex gets name index in ResultFields. -func GetResultFieldIndex(name string, fields []*ResultField, flag uint32) []int { +func GetResultFieldIndex(name string, fields []*ResultField) []int { var indices []int db, table, field := SplitQualifiedName(name) - - // Check origin field name. - if flag&OrgFieldNameFlag > 0 { - for i, f := range fields { - if checkFieldsEqual(db, table, field, f.DBName, f.TableName, f.ColumnInfo.Name.L) { - indices = append(indices, i) - continue - } - } - } - - // Check alias field name. - if flag&FieldNameFlag > 0 { - for i, f := range fields { - if checkFieldsEqual(db, table, field, f.DBName, f.TableName, f.Name) { - indices = append(indices, i) - continue - } - } - } - - return indices -} - -// GetFieldIndex gets name index in Fields. -func GetFieldIndex(name string, fields []*Field, flag uint32) []int { - var indices []int - - // Check origin field name. - if flag&OrgFieldNameFlag > 0 { - for i, f := range fields { - if CheckFieldsEqual(name, f.Expr.String()) { - indices = append(indices, i) - continue - } - } - } - - // Check alias field name. - if flag&FieldNameFlag > 0 { - for i, f := range fields { - if CheckFieldsEqual(name, f.AsName) { - indices = append(indices, i) - continue - } + for i, f := range fields { + if checkFieldsEqual(db, table, field, f.DBName, f.TableName, f.Name) { + indices = append(indices, i) + continue } } @@ -267,8 +182,8 @@ func checkFieldsEqual(xdb, xtable, xfield, ydb, ytable, yfield string) bool { } // CloneFieldByName clones a ResultField in ResultFields according to name. -func CloneFieldByName(name string, fields []*ResultField, flag uint32) (*ResultField, error) { - indices := GetResultFieldIndex(name, fields, flag) +func CloneFieldByName(name string, fields []*ResultField) (*ResultField, error) { + indices := GetResultFieldIndex(name, fields) if len(indices) == 0 { return nil, errors.Errorf("unknown field %s", name) } diff --git a/field/result_field_test.go b/field/result_field_test.go index 9d58d86a20..8931e38471 100644 --- a/field/result_field_test.go +++ b/field/result_field_test.go @@ -16,7 +16,6 @@ package field_test import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/column" - "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" @@ -95,10 +94,6 @@ func (*testResultFieldSuite) TestMain(c *C) { c.Assert(rs[2].Tp, Equals, mysql.TypeVarString) c.Assert(rs[3].Tp, Equals, mysql.TypeBlob) - // For CheckAmbiguousField - err := field.CheckAmbiguousField("c1", rs, field.OrgFieldNameFlag) - c.Assert(err, IsNil) - col4 := column.Col{ ColumnInfo: model.ColumnInfo{ FieldType: *types.NewFieldType(mysql.TypeVarchar), @@ -113,51 +108,20 @@ func (*testResultFieldSuite) TestMain(c *C) { DBName: "test", } rs = []*field.ResultField{r, r1, r2} - // r1 and r2 are ambiguous: same column name but different table names - err = field.CheckAmbiguousField("c2", rs, field.OrgFieldNameFlag) - c.Assert(err, NotNil) - // r1 and r2 with different alias name - err = field.CheckAmbiguousField("c2", rs, field.FieldNameFlag) - c.Assert(err, IsNil) // For CloneFieldByName - _, err = field.CloneFieldByName("cx", rs, field.OrgFieldNameFlag) + _, err := field.CloneFieldByName("cx", rs) c.Assert(err, NotNil) - _, err = field.CloneFieldByName("c2", rs, field.OrgFieldNameFlag) - c.Assert(err, IsNil) - - // For check all fields name - names := []string{"cx"} - err = field.CheckAllFieldNames(names, rs, field.OrgFieldNameFlag) - c.Assert(err, NotNil) - names = []string{"c1"} - err = field.CheckAllFieldNames(names, rs, field.OrgFieldNameFlag) + _, err = field.CloneFieldByName("c2", rs) c.Assert(err, IsNil) // For ContainAllFieldNames - names = []string{"cx", "c2"} - b := field.ContainAllFieldNames(names, rs, field.OrgFieldNameFlag) + names := []string{"cx", "c2"} + b := field.ContainAllFieldNames(names, rs) c.Assert(b, IsFalse) names = []string{"c2", "c1"} - b = field.ContainAllFieldNames(names, rs, field.OrgFieldNameFlag) + b = field.ContainAllFieldNames(names, rs) c.Assert(b, IsTrue) - - // For GetFieldIndex - f1 := &field.Field{ - Expr: &expression.Ident{CIStr: model.NewCIStr("c1")}, - AsName: "a", - } - f2 := &field.Field{ - Expr: &expression.Ident{CIStr: model.NewCIStr("c2")}, - AsName: "a", - } - fs := []*field.Field{f1, f2} - idxs := field.GetFieldIndex("c1", fs, field.OrgFieldNameFlag) - c.Assert(idxs, HasLen, 1) - - idxs = field.GetFieldIndex("a", fs, field.FieldNameFlag) - c.Assert(idxs, HasLen, 2) - } func (*testResultFieldSuite) TestCheckWildcard(c *C) { diff --git a/plan/plans/from.go b/plan/plans/from.go index a0c1450404..679a2fa18a 100644 --- a/plan/plans/from.go +++ b/plan/plans/from.go @@ -238,7 +238,7 @@ func (r *TableDefaultPlan) filter(ctx context.Context, expr expression.Expressio colNames := expression.MentionedColumns(expr) // make sure all mentioned column names are in Fields // if not, e.g. the expr has two table like t1.c1 = t2.c2, we can't use filter - if !field.ContainAllFieldNames(colNames, r.Fields, field.DefaultFieldFlag) { + if !field.ContainAllFieldNames(colNames, r.Fields) { return r, false, nil } } diff --git a/plan/plans/plans.go b/plan/plans/plans.go index 393aa98f28..31e05ed380 100644 --- a/plan/plans/plans.go +++ b/plan/plans/plans.go @@ -48,8 +48,8 @@ func isTableOrIndex(p plan.Plan) bool { } // GetIdentValue is a function that evaluate identifier value from row. -func GetIdentValue(name string, fields []*field.ResultField, row []interface{}, flag uint32) (interface{}, error) { - indices := field.GetResultFieldIndex(name, fields, flag) +func GetIdentValue(name string, fields []*field.ResultField, row []interface{}) (interface{}, error) { + indices := field.GetResultFieldIndex(name, fields) if len(indices) == 0 { return nil, errors.Errorf("unknown field %s", name) } diff --git a/plan/plans/row_stack.go b/plan/plans/row_stack.go index 7b10919d29..0f186458f0 100644 --- a/plan/plans/row_stack.go +++ b/plan/plans/row_stack.go @@ -180,7 +180,7 @@ func getIdentValueFromOuterQuery(ctx context.Context, name string) (interface{}, // first try to get from outer table reference. if t.FromData != nil { - v, err = GetIdentValue(name, t.FromDataFields, t.FromData, field.DefaultFieldFlag) + v, err = GetIdentValue(name, t.FromDataFields, t.FromData) if err == nil { // tell current subquery using outer query subquery.SetOuterQueryUsed(ctx) @@ -190,7 +190,7 @@ func getIdentValueFromOuterQuery(ctx context.Context, name string) (interface{}, // then try to get from outer select list. if t.OutData != nil { - v, err = GetIdentValue(name, t.OutDataFields, t.OutData, field.FieldNameFlag) + v, err = GetIdentValue(name, t.OutDataFields, t.OutData) if err == nil { // tell current subquery using outer query subquery.SetOuterQueryUsed(ctx) diff --git a/plan/plans/select_list.go b/plan/plans/select_list.go index 52b76fa26a..1a8d63da6b 100644 --- a/plan/plans/select_list.go +++ b/plan/plans/select_list.go @@ -225,7 +225,7 @@ func ResolveSelectList(selectFields []*field.Field, srcFields []*field.ResultFie // TODO: use fromIdentVisitor to cleanup. var result *field.ResultField for _, name := range names { - idx := field.GetResultFieldIndex(name, srcFields, field.DefaultFieldFlag) + idx := field.GetResultFieldIndex(name, srcFields) if len(idx) > 1 { return nil, errors.Errorf("ambiguous field %s", name) } @@ -238,7 +238,7 @@ func ResolveSelectList(selectFields []*field.Field, srcFields []*field.ResultFie if _, ok := v.Expr.(*expression.Ident); ok { // Field is ident. - if result, err = field.CloneFieldByName(name, srcFields, field.DefaultFieldFlag); err != nil { + if result, err = field.CloneFieldByName(name, srcFields); err != nil { return nil, errors.Trace(err) } diff --git a/rset/rsets/groupby.go b/rset/rsets/groupby.go index d95ce8f073..9ec51157dc 100644 --- a/rset/rsets/groupby.go +++ b/rset/rsets/groupby.go @@ -66,7 +66,7 @@ func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, } // first find this identifier in FROM. - idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag) + idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields) if len(idx) > 0 { i.ReferScope = expression.IdentReferFromTable i.ReferIndex = idx[0] diff --git a/rset/rsets/having.go b/rset/rsets/having.go index be43f3cc02..7ca80d4f73 100644 --- a/rset/rsets/having.go +++ b/rset/rsets/having.go @@ -66,7 +66,7 @@ func (v *havingVisitor) visitIdentInAggregate(i *expression.Ident) (expression.E // then in select list, and outer query finally. // find this identifier in FROM. - idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag) + idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields) if len(idx) > 0 { i.ReferScope = expression.IdentReferFromTable i.ReferIndex = idx[0] @@ -115,8 +115,7 @@ func (v *havingVisitor) checkIdentInGroupBy(i *expression.Ident) (*expression.Id // having is unqualified name, e.g, select * from t1, t2 group by t1.c having c. // both t1 and t2 have column c, we must check ambiguous here. - - idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag) + idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields) if len(idx) > 1 { return i, false, errors.Errorf("Column '%s' in having clause is ambiguous", i) } diff --git a/rset/rsets/helper.go b/rset/rsets/helper.go index 91aff99379..fa47f1d79d 100644 --- a/rset/rsets/helper.go +++ b/rset/rsets/helper.go @@ -129,7 +129,7 @@ type fromIdentVisitor struct { } func (v *fromIdentVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) { - idx := field.GetResultFieldIndex(i.L, v.fromFields, field.DefaultFieldFlag) + idx := field.GetResultFieldIndex(i.L, v.fromFields) if len(idx) == 1 { i.ReferScope = expression.IdentReferFromTable i.ReferIndex = idx[0] diff --git a/rset/rsets/orderby.go b/rset/rsets/orderby.go index 9412cc9faa..37ad99d0e9 100644 --- a/rset/rsets/orderby.go +++ b/rset/rsets/orderby.go @@ -126,7 +126,7 @@ func (v *orderByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, } // find this identifier in FROM. - idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag) + idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields) if len(idx) > 0 { i.ReferScope = expression.IdentReferFromTable i.ReferIndex = idx[0] diff --git a/stmt/stmts/update.go b/stmt/stmts/update.go index 1b6f46f4dc..4635ffc71b 100644 --- a/stmt/stmts/update.go +++ b/stmt/stmts/update.go @@ -107,7 +107,7 @@ func getUpdateColumns(assignList []expression.Assignment, fields []*field.Result name = fmt.Sprintf("%s.%s", v.TableName, v.ColName) } // use result fields to check assign list, otherwise use origin table columns - idx := field.GetResultFieldIndex(name, fields, field.DefaultFieldFlag) + idx := field.GetResultFieldIndex(name, fields) if n := len(idx); n > 1 { return nil, errors.Errorf("ambiguous field %s", name) } else if n == 0 { @@ -276,7 +276,7 @@ func (s *UpdateStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { // Set EvalIdentFunc m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { - return plans.GetIdentValue(name, p.GetFields(), rowData, field.DefaultFieldFlag) + return plans.GetIdentValue(name, p.GetFields(), rowData) } // Update rows