*: clean up
This commit is contained in:
@ -26,20 +26,6 @@ import (
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultFieldFlag is `FieldNameFlag`.
|
||||
DefaultFieldFlag = FieldNameFlag
|
||||
// CheckFieldFlag includes `FieldNameFlag` and `OrgFieldNameFlag`.
|
||||
CheckFieldFlag = FieldNameFlag | OrgFieldNameFlag
|
||||
)
|
||||
|
||||
const (
|
||||
// OrgFieldNameFlag is origin filed name flag.
|
||||
OrgFieldNameFlag uint32 = 1 << iota
|
||||
// FieldNameFlag is filed name flag.
|
||||
FieldNameFlag
|
||||
)
|
||||
|
||||
// ResultField provides meta data of table column.
|
||||
type ResultField struct {
|
||||
column.Col // Col.Name is OrgName.
|
||||
@ -96,47 +82,17 @@ func ColToResultField(col *column.Col, tableName string) *ResultField {
|
||||
return rf
|
||||
}
|
||||
|
||||
// CheckAmbiguousField checks whether name is an ambiguous field in ResultFields.
|
||||
func CheckAmbiguousField(name string, fields []*ResultField, flag uint32) error {
|
||||
indices := GetResultFieldIndex(name, fields, flag)
|
||||
if len(indices) > 1 {
|
||||
return errors.Errorf("ambiguous field %s", name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckFieldName checks whether name is an unknown or ambiguous field in ResultFields.
|
||||
func CheckFieldName(name string, fields []*ResultField, flag uint32) error {
|
||||
if !ContainFieldName(name, fields, flag) {
|
||||
return errors.Errorf("unknown field %s", name)
|
||||
}
|
||||
|
||||
return CheckAmbiguousField(name, fields, OrgFieldNameFlag)
|
||||
}
|
||||
|
||||
// CheckAllFieldNames checks whether names has unknown or ambiguous field in ResultFields.
|
||||
func CheckAllFieldNames(names []string, fields []*ResultField, flag uint32) error {
|
||||
for _, name := range names {
|
||||
if err := CheckFieldName(name, fields, flag); err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ContainFieldName checks whether name is in ResultFields.
|
||||
func ContainFieldName(name string, fields []*ResultField, flag uint32) bool {
|
||||
indices := GetResultFieldIndex(name, fields, flag)
|
||||
func ContainFieldName(name string, fields []*ResultField) bool {
|
||||
indices := GetResultFieldIndex(name, fields)
|
||||
return len(indices) > 0
|
||||
}
|
||||
|
||||
// ContainAllFieldNames checks whether names are all in ResultFields.
|
||||
// TODO: add alias table name support
|
||||
func ContainAllFieldNames(names []string, fields []*ResultField, flag uint32) bool {
|
||||
func ContainAllFieldNames(names []string, fields []*ResultField) bool {
|
||||
for _, name := range names {
|
||||
if !ContainFieldName(name, fields, flag) {
|
||||
if !ContainFieldName(name, fields) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@ -179,55 +135,14 @@ func JoinQualifiedName(db string, table string, field string) string {
|
||||
}
|
||||
|
||||
// GetResultFieldIndex gets name index in ResultFields.
|
||||
func GetResultFieldIndex(name string, fields []*ResultField, flag uint32) []int {
|
||||
func GetResultFieldIndex(name string, fields []*ResultField) []int {
|
||||
var indices []int
|
||||
|
||||
db, table, field := SplitQualifiedName(name)
|
||||
|
||||
// Check origin field name.
|
||||
if flag&OrgFieldNameFlag > 0 {
|
||||
for i, f := range fields {
|
||||
if checkFieldsEqual(db, table, field, f.DBName, f.TableName, f.ColumnInfo.Name.L) {
|
||||
indices = append(indices, i)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check alias field name.
|
||||
if flag&FieldNameFlag > 0 {
|
||||
for i, f := range fields {
|
||||
if checkFieldsEqual(db, table, field, f.DBName, f.TableName, f.Name) {
|
||||
indices = append(indices, i)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return indices
|
||||
}
|
||||
|
||||
// GetFieldIndex gets name index in Fields.
|
||||
func GetFieldIndex(name string, fields []*Field, flag uint32) []int {
|
||||
var indices []int
|
||||
|
||||
// Check origin field name.
|
||||
if flag&OrgFieldNameFlag > 0 {
|
||||
for i, f := range fields {
|
||||
if CheckFieldsEqual(name, f.Expr.String()) {
|
||||
indices = append(indices, i)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check alias field name.
|
||||
if flag&FieldNameFlag > 0 {
|
||||
for i, f := range fields {
|
||||
if CheckFieldsEqual(name, f.AsName) {
|
||||
indices = append(indices, i)
|
||||
continue
|
||||
}
|
||||
for i, f := range fields {
|
||||
if checkFieldsEqual(db, table, field, f.DBName, f.TableName, f.Name) {
|
||||
indices = append(indices, i)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@ -267,8 +182,8 @@ func checkFieldsEqual(xdb, xtable, xfield, ydb, ytable, yfield string) bool {
|
||||
}
|
||||
|
||||
// CloneFieldByName clones a ResultField in ResultFields according to name.
|
||||
func CloneFieldByName(name string, fields []*ResultField, flag uint32) (*ResultField, error) {
|
||||
indices := GetResultFieldIndex(name, fields, flag)
|
||||
func CloneFieldByName(name string, fields []*ResultField) (*ResultField, error) {
|
||||
indices := GetResultFieldIndex(name, fields)
|
||||
if len(indices) == 0 {
|
||||
return nil, errors.Errorf("unknown field %s", name)
|
||||
}
|
||||
|
||||
@ -16,7 +16,6 @@ package field_test
|
||||
import (
|
||||
. "github.com/pingcap/check"
|
||||
"github.com/pingcap/tidb/column"
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/field"
|
||||
"github.com/pingcap/tidb/model"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
@ -95,10 +94,6 @@ func (*testResultFieldSuite) TestMain(c *C) {
|
||||
c.Assert(rs[2].Tp, Equals, mysql.TypeVarString)
|
||||
c.Assert(rs[3].Tp, Equals, mysql.TypeBlob)
|
||||
|
||||
// For CheckAmbiguousField
|
||||
err := field.CheckAmbiguousField("c1", rs, field.OrgFieldNameFlag)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
col4 := column.Col{
|
||||
ColumnInfo: model.ColumnInfo{
|
||||
FieldType: *types.NewFieldType(mysql.TypeVarchar),
|
||||
@ -113,51 +108,20 @@ func (*testResultFieldSuite) TestMain(c *C) {
|
||||
DBName: "test",
|
||||
}
|
||||
rs = []*field.ResultField{r, r1, r2}
|
||||
// r1 and r2 are ambiguous: same column name but different table names
|
||||
err = field.CheckAmbiguousField("c2", rs, field.OrgFieldNameFlag)
|
||||
c.Assert(err, NotNil)
|
||||
// r1 and r2 with different alias name
|
||||
err = field.CheckAmbiguousField("c2", rs, field.FieldNameFlag)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
// For CloneFieldByName
|
||||
_, err = field.CloneFieldByName("cx", rs, field.OrgFieldNameFlag)
|
||||
_, err := field.CloneFieldByName("cx", rs)
|
||||
c.Assert(err, NotNil)
|
||||
_, err = field.CloneFieldByName("c2", rs, field.OrgFieldNameFlag)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
// For check all fields name
|
||||
names := []string{"cx"}
|
||||
err = field.CheckAllFieldNames(names, rs, field.OrgFieldNameFlag)
|
||||
c.Assert(err, NotNil)
|
||||
names = []string{"c1"}
|
||||
err = field.CheckAllFieldNames(names, rs, field.OrgFieldNameFlag)
|
||||
_, err = field.CloneFieldByName("c2", rs)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
// For ContainAllFieldNames
|
||||
names = []string{"cx", "c2"}
|
||||
b := field.ContainAllFieldNames(names, rs, field.OrgFieldNameFlag)
|
||||
names := []string{"cx", "c2"}
|
||||
b := field.ContainAllFieldNames(names, rs)
|
||||
c.Assert(b, IsFalse)
|
||||
names = []string{"c2", "c1"}
|
||||
b = field.ContainAllFieldNames(names, rs, field.OrgFieldNameFlag)
|
||||
b = field.ContainAllFieldNames(names, rs)
|
||||
c.Assert(b, IsTrue)
|
||||
|
||||
// For GetFieldIndex
|
||||
f1 := &field.Field{
|
||||
Expr: &expression.Ident{CIStr: model.NewCIStr("c1")},
|
||||
AsName: "a",
|
||||
}
|
||||
f2 := &field.Field{
|
||||
Expr: &expression.Ident{CIStr: model.NewCIStr("c2")},
|
||||
AsName: "a",
|
||||
}
|
||||
fs := []*field.Field{f1, f2}
|
||||
idxs := field.GetFieldIndex("c1", fs, field.OrgFieldNameFlag)
|
||||
c.Assert(idxs, HasLen, 1)
|
||||
|
||||
idxs = field.GetFieldIndex("a", fs, field.FieldNameFlag)
|
||||
c.Assert(idxs, HasLen, 2)
|
||||
|
||||
}
|
||||
|
||||
func (*testResultFieldSuite) TestCheckWildcard(c *C) {
|
||||
|
||||
@ -238,7 +238,7 @@ func (r *TableDefaultPlan) filter(ctx context.Context, expr expression.Expressio
|
||||
colNames := expression.MentionedColumns(expr)
|
||||
// make sure all mentioned column names are in Fields
|
||||
// if not, e.g. the expr has two table like t1.c1 = t2.c2, we can't use filter
|
||||
if !field.ContainAllFieldNames(colNames, r.Fields, field.DefaultFieldFlag) {
|
||||
if !field.ContainAllFieldNames(colNames, r.Fields) {
|
||||
return r, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -48,8 +48,8 @@ func isTableOrIndex(p plan.Plan) bool {
|
||||
}
|
||||
|
||||
// GetIdentValue is a function that evaluate identifier value from row.
|
||||
func GetIdentValue(name string, fields []*field.ResultField, row []interface{}, flag uint32) (interface{}, error) {
|
||||
indices := field.GetResultFieldIndex(name, fields, flag)
|
||||
func GetIdentValue(name string, fields []*field.ResultField, row []interface{}) (interface{}, error) {
|
||||
indices := field.GetResultFieldIndex(name, fields)
|
||||
if len(indices) == 0 {
|
||||
return nil, errors.Errorf("unknown field %s", name)
|
||||
}
|
||||
|
||||
@ -180,7 +180,7 @@ func getIdentValueFromOuterQuery(ctx context.Context, name string) (interface{},
|
||||
|
||||
// first try to get from outer table reference.
|
||||
if t.FromData != nil {
|
||||
v, err = GetIdentValue(name, t.FromDataFields, t.FromData, field.DefaultFieldFlag)
|
||||
v, err = GetIdentValue(name, t.FromDataFields, t.FromData)
|
||||
if err == nil {
|
||||
// tell current subquery using outer query
|
||||
subquery.SetOuterQueryUsed(ctx)
|
||||
@ -190,7 +190,7 @@ func getIdentValueFromOuterQuery(ctx context.Context, name string) (interface{},
|
||||
|
||||
// then try to get from outer select list.
|
||||
if t.OutData != nil {
|
||||
v, err = GetIdentValue(name, t.OutDataFields, t.OutData, field.FieldNameFlag)
|
||||
v, err = GetIdentValue(name, t.OutDataFields, t.OutData)
|
||||
if err == nil {
|
||||
// tell current subquery using outer query
|
||||
subquery.SetOuterQueryUsed(ctx)
|
||||
|
||||
@ -225,7 +225,7 @@ func ResolveSelectList(selectFields []*field.Field, srcFields []*field.ResultFie
|
||||
// TODO: use fromIdentVisitor to cleanup.
|
||||
var result *field.ResultField
|
||||
for _, name := range names {
|
||||
idx := field.GetResultFieldIndex(name, srcFields, field.DefaultFieldFlag)
|
||||
idx := field.GetResultFieldIndex(name, srcFields)
|
||||
if len(idx) > 1 {
|
||||
return nil, errors.Errorf("ambiguous field %s", name)
|
||||
}
|
||||
@ -238,7 +238,7 @@ func ResolveSelectList(selectFields []*field.Field, srcFields []*field.ResultFie
|
||||
|
||||
if _, ok := v.Expr.(*expression.Ident); ok {
|
||||
// Field is ident.
|
||||
if result, err = field.CloneFieldByName(name, srcFields, field.DefaultFieldFlag); err != nil {
|
||||
if result, err = field.CloneFieldByName(name, srcFields); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
|
||||
@ -66,7 +66,7 @@ func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression,
|
||||
}
|
||||
|
||||
// first find this identifier in FROM.
|
||||
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag)
|
||||
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields)
|
||||
if len(idx) > 0 {
|
||||
i.ReferScope = expression.IdentReferFromTable
|
||||
i.ReferIndex = idx[0]
|
||||
|
||||
@ -66,7 +66,7 @@ func (v *havingVisitor) visitIdentInAggregate(i *expression.Ident) (expression.E
|
||||
// then in select list, and outer query finally.
|
||||
|
||||
// find this identifier in FROM.
|
||||
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag)
|
||||
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields)
|
||||
if len(idx) > 0 {
|
||||
i.ReferScope = expression.IdentReferFromTable
|
||||
i.ReferIndex = idx[0]
|
||||
@ -115,8 +115,7 @@ func (v *havingVisitor) checkIdentInGroupBy(i *expression.Ident) (*expression.Id
|
||||
|
||||
// having is unqualified name, e.g, select * from t1, t2 group by t1.c having c.
|
||||
// both t1 and t2 have column c, we must check ambiguous here.
|
||||
|
||||
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag)
|
||||
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields)
|
||||
if len(idx) > 1 {
|
||||
return i, false, errors.Errorf("Column '%s' in having clause is ambiguous", i)
|
||||
}
|
||||
|
||||
@ -129,7 +129,7 @@ type fromIdentVisitor struct {
|
||||
}
|
||||
|
||||
func (v *fromIdentVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) {
|
||||
idx := field.GetResultFieldIndex(i.L, v.fromFields, field.DefaultFieldFlag)
|
||||
idx := field.GetResultFieldIndex(i.L, v.fromFields)
|
||||
if len(idx) == 1 {
|
||||
i.ReferScope = expression.IdentReferFromTable
|
||||
i.ReferIndex = idx[0]
|
||||
|
||||
@ -126,7 +126,7 @@ func (v *orderByVisitor) VisitIdent(i *expression.Ident) (expression.Expression,
|
||||
}
|
||||
|
||||
// find this identifier in FROM.
|
||||
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag)
|
||||
idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields)
|
||||
if len(idx) > 0 {
|
||||
i.ReferScope = expression.IdentReferFromTable
|
||||
i.ReferIndex = idx[0]
|
||||
|
||||
@ -107,7 +107,7 @@ func getUpdateColumns(assignList []expression.Assignment, fields []*field.Result
|
||||
name = fmt.Sprintf("%s.%s", v.TableName, v.ColName)
|
||||
}
|
||||
// use result fields to check assign list, otherwise use origin table columns
|
||||
idx := field.GetResultFieldIndex(name, fields, field.DefaultFieldFlag)
|
||||
idx := field.GetResultFieldIndex(name, fields)
|
||||
if n := len(idx); n > 1 {
|
||||
return nil, errors.Errorf("ambiguous field %s", name)
|
||||
} else if n == 0 {
|
||||
@ -276,7 +276,7 @@ func (s *UpdateStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) {
|
||||
|
||||
// Set EvalIdentFunc
|
||||
m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
|
||||
return plans.GetIdentValue(name, p.GetFields(), rowData, field.DefaultFieldFlag)
|
||||
return plans.GetIdentValue(name, p.GetFields(), rowData)
|
||||
}
|
||||
|
||||
// Update rows
|
||||
|
||||
Reference in New Issue
Block a user