*: clean up

This commit is contained in:
qiuyesuifeng
2015-10-20 12:19:06 +08:00
parent 58e6d7fd0a
commit 4aeaeada1d
11 changed files with 30 additions and 152 deletions

View File

@ -26,20 +26,6 @@ import (
"github.com/pingcap/tidb/mysql"
)
const (
// DefaultFieldFlag is `FieldNameFlag`.
DefaultFieldFlag = FieldNameFlag
// CheckFieldFlag includes `FieldNameFlag` and `OrgFieldNameFlag`.
CheckFieldFlag = FieldNameFlag | OrgFieldNameFlag
)
const (
// OrgFieldNameFlag is origin filed name flag.
OrgFieldNameFlag uint32 = 1 << iota
// FieldNameFlag is filed name flag.
FieldNameFlag
)
// ResultField provides meta data of table column.
type ResultField struct {
column.Col // Col.Name is OrgName.
@ -96,47 +82,17 @@ func ColToResultField(col *column.Col, tableName string) *ResultField {
return rf
}
// CheckAmbiguousField checks whether name is an ambiguous field in ResultFields.
func CheckAmbiguousField(name string, fields []*ResultField, flag uint32) error {
indices := GetResultFieldIndex(name, fields, flag)
if len(indices) > 1 {
return errors.Errorf("ambiguous field %s", name)
}
return nil
}
// CheckFieldName checks whether name is an unknown or ambiguous field in ResultFields.
func CheckFieldName(name string, fields []*ResultField, flag uint32) error {
if !ContainFieldName(name, fields, flag) {
return errors.Errorf("unknown field %s", name)
}
return CheckAmbiguousField(name, fields, OrgFieldNameFlag)
}
// CheckAllFieldNames checks whether names has unknown or ambiguous field in ResultFields.
func CheckAllFieldNames(names []string, fields []*ResultField, flag uint32) error {
for _, name := range names {
if err := CheckFieldName(name, fields, flag); err != nil {
return errors.Trace(err)
}
}
return nil
}
// ContainFieldName checks whether name is in ResultFields.
func ContainFieldName(name string, fields []*ResultField, flag uint32) bool {
indices := GetResultFieldIndex(name, fields, flag)
func ContainFieldName(name string, fields []*ResultField) bool {
indices := GetResultFieldIndex(name, fields)
return len(indices) > 0
}
// ContainAllFieldNames checks whether names are all in ResultFields.
// TODO: add alias table name support
func ContainAllFieldNames(names []string, fields []*ResultField, flag uint32) bool {
func ContainAllFieldNames(names []string, fields []*ResultField) bool {
for _, name := range names {
if !ContainFieldName(name, fields, flag) {
if !ContainFieldName(name, fields) {
return false
}
}
@ -179,55 +135,14 @@ func JoinQualifiedName(db string, table string, field string) string {
}
// GetResultFieldIndex gets name index in ResultFields.
func GetResultFieldIndex(name string, fields []*ResultField, flag uint32) []int {
func GetResultFieldIndex(name string, fields []*ResultField) []int {
var indices []int
db, table, field := SplitQualifiedName(name)
// Check origin field name.
if flag&OrgFieldNameFlag > 0 {
for i, f := range fields {
if checkFieldsEqual(db, table, field, f.DBName, f.TableName, f.ColumnInfo.Name.L) {
indices = append(indices, i)
continue
}
}
}
// Check alias field name.
if flag&FieldNameFlag > 0 {
for i, f := range fields {
if checkFieldsEqual(db, table, field, f.DBName, f.TableName, f.Name) {
indices = append(indices, i)
continue
}
}
}
return indices
}
// GetFieldIndex gets name index in Fields.
func GetFieldIndex(name string, fields []*Field, flag uint32) []int {
var indices []int
// Check origin field name.
if flag&OrgFieldNameFlag > 0 {
for i, f := range fields {
if CheckFieldsEqual(name, f.Expr.String()) {
indices = append(indices, i)
continue
}
}
}
// Check alias field name.
if flag&FieldNameFlag > 0 {
for i, f := range fields {
if CheckFieldsEqual(name, f.AsName) {
indices = append(indices, i)
continue
}
for i, f := range fields {
if checkFieldsEqual(db, table, field, f.DBName, f.TableName, f.Name) {
indices = append(indices, i)
continue
}
}
@ -267,8 +182,8 @@ func checkFieldsEqual(xdb, xtable, xfield, ydb, ytable, yfield string) bool {
}
// CloneFieldByName clones a ResultField in ResultFields according to name.
func CloneFieldByName(name string, fields []*ResultField, flag uint32) (*ResultField, error) {
indices := GetResultFieldIndex(name, fields, flag)
func CloneFieldByName(name string, fields []*ResultField) (*ResultField, error) {
indices := GetResultFieldIndex(name, fields)
if len(indices) == 0 {
return nil, errors.Errorf("unknown field %s", name)
}

View File

@ -16,7 +16,6 @@ package field_test
import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/column"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/field"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/mysql"
@ -95,10 +94,6 @@ func (*testResultFieldSuite) TestMain(c *C) {
c.Assert(rs[2].Tp, Equals, mysql.TypeVarString)
c.Assert(rs[3].Tp, Equals, mysql.TypeBlob)
// For CheckAmbiguousField
err := field.CheckAmbiguousField("c1", rs, field.OrgFieldNameFlag)
c.Assert(err, IsNil)
col4 := column.Col{
ColumnInfo: model.ColumnInfo{
FieldType: *types.NewFieldType(mysql.TypeVarchar),
@ -113,51 +108,20 @@ func (*testResultFieldSuite) TestMain(c *C) {
DBName: "test",
}
rs = []*field.ResultField{r, r1, r2}
// r1 and r2 are ambiguous: same column name but different table names
err = field.CheckAmbiguousField("c2", rs, field.OrgFieldNameFlag)
c.Assert(err, NotNil)
// r1 and r2 with different alias name
err = field.CheckAmbiguousField("c2", rs, field.FieldNameFlag)
c.Assert(err, IsNil)
// For CloneFieldByName
_, err = field.CloneFieldByName("cx", rs, field.OrgFieldNameFlag)
_, err := field.CloneFieldByName("cx", rs)
c.Assert(err, NotNil)
_, err = field.CloneFieldByName("c2", rs, field.OrgFieldNameFlag)
c.Assert(err, IsNil)
// For check all fields name
names := []string{"cx"}
err = field.CheckAllFieldNames(names, rs, field.OrgFieldNameFlag)
c.Assert(err, NotNil)
names = []string{"c1"}
err = field.CheckAllFieldNames(names, rs, field.OrgFieldNameFlag)
_, err = field.CloneFieldByName("c2", rs)
c.Assert(err, IsNil)
// For ContainAllFieldNames
names = []string{"cx", "c2"}
b := field.ContainAllFieldNames(names, rs, field.OrgFieldNameFlag)
names := []string{"cx", "c2"}
b := field.ContainAllFieldNames(names, rs)
c.Assert(b, IsFalse)
names = []string{"c2", "c1"}
b = field.ContainAllFieldNames(names, rs, field.OrgFieldNameFlag)
b = field.ContainAllFieldNames(names, rs)
c.Assert(b, IsTrue)
// For GetFieldIndex
f1 := &field.Field{
Expr: &expression.Ident{CIStr: model.NewCIStr("c1")},
AsName: "a",
}
f2 := &field.Field{
Expr: &expression.Ident{CIStr: model.NewCIStr("c2")},
AsName: "a",
}
fs := []*field.Field{f1, f2}
idxs := field.GetFieldIndex("c1", fs, field.OrgFieldNameFlag)
c.Assert(idxs, HasLen, 1)
idxs = field.GetFieldIndex("a", fs, field.FieldNameFlag)
c.Assert(idxs, HasLen, 2)
}
func (*testResultFieldSuite) TestCheckWildcard(c *C) {

View File

@ -238,7 +238,7 @@ func (r *TableDefaultPlan) filter(ctx context.Context, expr expression.Expressio
colNames := expression.MentionedColumns(expr)
// make sure all mentioned column names are in Fields
// if not, e.g. the expr has two table like t1.c1 = t2.c2, we can't use filter
if !field.ContainAllFieldNames(colNames, r.Fields, field.DefaultFieldFlag) {
if !field.ContainAllFieldNames(colNames, r.Fields) {
return r, false, nil
}
}

View File

@ -48,8 +48,8 @@ func isTableOrIndex(p plan.Plan) bool {
}
// GetIdentValue is a function that evaluate identifier value from row.
func GetIdentValue(name string, fields []*field.ResultField, row []interface{}, flag uint32) (interface{}, error) {
indices := field.GetResultFieldIndex(name, fields, flag)
func GetIdentValue(name string, fields []*field.ResultField, row []interface{}) (interface{}, error) {
indices := field.GetResultFieldIndex(name, fields)
if len(indices) == 0 {
return nil, errors.Errorf("unknown field %s", name)
}

View File

@ -180,7 +180,7 @@ func getIdentValueFromOuterQuery(ctx context.Context, name string) (interface{},
// first try to get from outer table reference.
if t.FromData != nil {
v, err = GetIdentValue(name, t.FromDataFields, t.FromData, field.DefaultFieldFlag)
v, err = GetIdentValue(name, t.FromDataFields, t.FromData)
if err == nil {
// tell current subquery using outer query
subquery.SetOuterQueryUsed(ctx)
@ -190,7 +190,7 @@ func getIdentValueFromOuterQuery(ctx context.Context, name string) (interface{},
// then try to get from outer select list.
if t.OutData != nil {
v, err = GetIdentValue(name, t.OutDataFields, t.OutData, field.FieldNameFlag)
v, err = GetIdentValue(name, t.OutDataFields, t.OutData)
if err == nil {
// tell current subquery using outer query
subquery.SetOuterQueryUsed(ctx)

View File

@ -225,7 +225,7 @@ func ResolveSelectList(selectFields []*field.Field, srcFields []*field.ResultFie
// TODO: use fromIdentVisitor to cleanup.
var result *field.ResultField
for _, name := range names {
idx := field.GetResultFieldIndex(name, srcFields, field.DefaultFieldFlag)
idx := field.GetResultFieldIndex(name, srcFields)
if len(idx) > 1 {
return nil, errors.Errorf("ambiguous field %s", name)
}
@ -238,7 +238,7 @@ func ResolveSelectList(selectFields []*field.Field, srcFields []*field.ResultFie
if _, ok := v.Expr.(*expression.Ident); ok {
// Field is ident.
if result, err = field.CloneFieldByName(name, srcFields, field.DefaultFieldFlag); err != nil {
if result, err = field.CloneFieldByName(name, srcFields); err != nil {
return nil, errors.Trace(err)
}

View File

@ -66,7 +66,7 @@ func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression,
}
// first find this identifier in FROM.
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag)
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields)
if len(idx) > 0 {
i.ReferScope = expression.IdentReferFromTable
i.ReferIndex = idx[0]

View File

@ -66,7 +66,7 @@ func (v *havingVisitor) visitIdentInAggregate(i *expression.Ident) (expression.E
// then in select list, and outer query finally.
// find this identifier in FROM.
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag)
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields)
if len(idx) > 0 {
i.ReferScope = expression.IdentReferFromTable
i.ReferIndex = idx[0]
@ -115,8 +115,7 @@ func (v *havingVisitor) checkIdentInGroupBy(i *expression.Ident) (*expression.Id
// having is unqualified name, e.g, select * from t1, t2 group by t1.c having c.
// both t1 and t2 have column c, we must check ambiguous here.
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag)
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields)
if len(idx) > 1 {
return i, false, errors.Errorf("Column '%s' in having clause is ambiguous", i)
}

View File

@ -129,7 +129,7 @@ type fromIdentVisitor struct {
}
func (v *fromIdentVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) {
idx := field.GetResultFieldIndex(i.L, v.fromFields, field.DefaultFieldFlag)
idx := field.GetResultFieldIndex(i.L, v.fromFields)
if len(idx) == 1 {
i.ReferScope = expression.IdentReferFromTable
i.ReferIndex = idx[0]

View File

@ -126,7 +126,7 @@ func (v *orderByVisitor) VisitIdent(i *expression.Ident) (expression.Expression,
}
// find this identifier in FROM.
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag)
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields)
if len(idx) > 0 {
i.ReferScope = expression.IdentReferFromTable
i.ReferIndex = idx[0]

View File

@ -107,7 +107,7 @@ func getUpdateColumns(assignList []expression.Assignment, fields []*field.Result
name = fmt.Sprintf("%s.%s", v.TableName, v.ColName)
}
// use result fields to check assign list, otherwise use origin table columns
idx := field.GetResultFieldIndex(name, fields, field.DefaultFieldFlag)
idx := field.GetResultFieldIndex(name, fields)
if n := len(idx); n > 1 {
return nil, errors.Errorf("ambiguous field %s", name)
} else if n == 0 {
@ -276,7 +276,7 @@ func (s *UpdateStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) {
// Set EvalIdentFunc
m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
return plans.GetIdentValue(name, p.GetFields(), rowData, field.DefaultFieldFlag)
return plans.GetIdentValue(name, p.GetFields(), rowData)
}
// Update rows