diff --git a/ast/ddl.go b/ast/ddl.go index 5bb658f172..9a8cd44752 100644 --- a/ast/ddl.go +++ b/ast/ddl.go @@ -18,21 +18,22 @@ import ( ) var ( - _ DDLNode = &CreateDatabaseStmt{} - _ DDLNode = &DropDatabaseStmt{} - _ DDLNode = &CreateTableStmt{} - _ DDLNode = &DropTableStmt{} - _ DDLNode = &CreateIndexStmt{} - _ DDLNode = &DropTableStmt{} _ DDLNode = &AlterTableStmt{} + _ DDLNode = &CreateDatabaseStmt{} + _ DDLNode = &CreateIndexStmt{} + _ DDLNode = &CreateTableStmt{} + _ DDLNode = &DropDatabaseStmt{} + _ DDLNode = &DropIndexStmt{} + _ DDLNode = &DropTableStmt{} _ DDLNode = &TruncateTableStmt{} - _ Node = &IndexColName{} - _ Node = &ReferenceDef{} - _ Node = &ColumnOption{} - _ Node = &Constraint{} - _ Node = &ColumnDef{} - _ Node = &ColumnPosition{} - _ Node = &AlterTableSpec{} + + _ Node = &AlterTableSpec{} + _ Node = &ColumnDef{} + _ Node = &ColumnOption{} + _ Node = &ColumnPosition{} + _ Node = &Constraint{} + _ Node = &IndexColName{} + _ Node = &ReferenceDef{} ) // CharsetOpt is used for parsing charset option from SQL. @@ -53,8 +54,6 @@ const ( // DatabaseOption represents database option. type DatabaseOption struct { - node - Tp DatabaseOptionType Value string } @@ -71,11 +70,11 @@ type CreateDatabaseStmt struct { // Accept implements Node Accept interface. func (n *CreateDatabaseStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*CreateDatabaseStmt) + n = newNode.(*CreateDatabaseStmt) return v.Leave(n) } @@ -90,11 +89,11 @@ type DropDatabaseStmt struct { // Accept implements Node Accept interface. func (n *DropDatabaseStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*DropDatabaseStmt) + n = newNode.(*DropDatabaseStmt) return v.Leave(n) } @@ -108,11 +107,11 @@ type IndexColName struct { // Accept implements Node Accept interface. func (n *IndexColName) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*IndexColName) + n = newNode.(*IndexColName) node, ok := n.Column.Accept(v) if !ok { return n, false @@ -132,11 +131,11 @@ type ReferenceDef struct { // Accept implements Node Accept interface. func (n *ReferenceDef) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*ReferenceDef) + n = newNode.(*ReferenceDef) node, ok := n.Table.Accept(v) if !ok { return n, false @@ -184,11 +183,11 @@ type ColumnOption struct { // Accept implements Node Accept interface. func (n *ColumnOption) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*ColumnOption) + n = newNode.(*ColumnOption) if n.Expr != nil { node, ok := n.Expr.Accept(v) if !ok { @@ -231,11 +230,11 @@ type Constraint struct { // Accept implements Node Accept interface. func (n *Constraint) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*Constraint) + n = newNode.(*Constraint) for i, val := range n.Keys { node, ok := val.Accept(v) if !ok { @@ -264,11 +263,11 @@ type ColumnDef struct { // Accept implements Node Accept interface. func (n *ColumnDef) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*ColumnDef) + n = newNode.(*ColumnDef) node, ok := n.Name.Accept(v) if !ok { return n, false @@ -298,11 +297,11 @@ type CreateTableStmt struct { // Accept implements Node Accept interface. func (n *CreateTableStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*CreateTableStmt) + n = newNode.(*CreateTableStmt) node, ok := n.Table.Accept(v) if !ok { return n, false @@ -336,11 +335,11 @@ type DropTableStmt struct { // Accept implements Node Accept interface. func (n *DropTableStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*DropTableStmt) + n = newNode.(*DropTableStmt) for i, val := range n.Tables { node, ok := val.Accept(v) if !ok { @@ -364,11 +363,11 @@ type CreateIndexStmt struct { // Accept implements Node Accept interface. func (n *CreateIndexStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*CreateIndexStmt) + n = newNode.(*CreateIndexStmt) node, ok := n.Table.Accept(v) if !ok { return n, false @@ -396,11 +395,11 @@ type DropIndexStmt struct { // Accept implements Node Accept interface. func (n *DropIndexStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*DropIndexStmt) + n = newNode.(*DropIndexStmt) node, ok := n.Table.Accept(v) if !ok { return n, false @@ -459,11 +458,11 @@ type ColumnPosition struct { // Accept implements Node Accept interface. func (n *ColumnPosition) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*ColumnPosition) + n = newNode.(*ColumnPosition) if n.RelativeColumn != nil { node, ok := n.RelativeColumn.Accept(v) if !ok { @@ -505,11 +504,11 @@ type AlterTableSpec struct { // Accept implements Node Accept interface. func (n *AlterTableSpec) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*AlterTableSpec) + n = newNode.(*AlterTableSpec) if n.Constraint != nil { node, ok := n.Constraint.Accept(v) if !ok { @@ -552,11 +551,11 @@ type AlterTableStmt struct { // Accept implements Node Accept interface. func (n *AlterTableStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*AlterTableStmt) + n = newNode.(*AlterTableStmt) node, ok := n.Table.Accept(v) if !ok { return n, false @@ -582,11 +581,11 @@ type TruncateTableStmt struct { // Accept implements Node Accept interface. func (n *TruncateTableStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*TruncateTableStmt) + n = newNode.(*TruncateTableStmt) node, ok := n.Table.Accept(v) if !ok { return n, false diff --git a/ast/dml.go b/ast/dml.go index e15b5f83bb..1402cf9699 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -18,18 +18,28 @@ import ( ) var ( - _ DMLNode = &InsertStmt{} _ DMLNode = &DeleteStmt{} + _ DMLNode = &InsertStmt{} + _ DMLNode = &UnionStmt{} _ DMLNode = &UpdateStmt{} _ DMLNode = &SelectStmt{} - _ DMLNode = &UnionStmt{} - _ Node = &Join{} - _ Node = &TableName{} - _ Node = &TableSource{} - _ Node = &Assignment{} - _ Node = &Limit{} - _ Node = &WildCardField{} - _ Node = &SelectField{} + _ DMLNode = &ShowStmt{} + + _ Node = &Assignment{} + _ Node = &ByItem{} + _ Node = &FieldList{} + _ Node = &GroupByClause{} + _ Node = &HavingClause{} + _ Node = &Join{} + _ Node = &Limit{} + _ Node = &OnCondition{} + _ Node = &OrderByClause{} + _ Node = &SelectField{} + _ Node = &TableName{} + _ Node = &TableRefsClause{} + _ Node = &TableSource{} + _ Node = &UnionClause{} + _ Node = &WildCardField{} ) // JoinType is join type, including cross/left/right/full. @@ -61,11 +71,11 @@ type Join struct { // Accept implements Node Accept interface. func (n *Join) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*Join) + n = newNode.(*Join) node, ok := n.Left.Accept(v) if !ok { return n, false @@ -102,38 +112,11 @@ type TableName struct { // Accept implements Node Accept interface. func (n *TableName) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*TableName) - return v.Leave(n) -} - -// TableSource represents table source with a name. -type TableSource struct { - node - - // Source is the source of the data, can be a TableName, - // a SelectStmt, a UnionStmt, or a JoinNode. - Source ResultSetNode - - // AsName is the as name of the table source. - AsName model.CIStr -} - -// Accept implements Node Accept interface. -func (n *TableSource) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNod) - } - n = newNod.(*TableSource) - node, ok := n.Source.Accept(v) - if !ok { - return n, false - } - n.Source = node.(ResultSetNode) + n = newNode.(*TableName) return v.Leave(n) } @@ -146,11 +129,11 @@ type OnCondition struct { // Accept implements Node Accept interface. func (n *OnCondition) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*OnCondition) + n = newNode.(*OnCondition) node, ok := n.Expr.Accept(v) if !ok { return n, false @@ -159,12 +142,39 @@ func (n *OnCondition) Accept(v Visitor) (Node, bool) { return v.Leave(n) } -// SetResultFields implements ResultSet interface. +// TableSource represents table source with a name. +type TableSource struct { + node + + // Source is the source of the data, can be a TableName, + // a SelectStmt, a UnionStmt, or a JoinNode. + Source ResultSetNode + + // AsName is the alias name of the table source. + AsName model.CIStr +} + +// Accept implements Node Accept interface. +func (n *TableSource) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*TableSource) + node, ok := n.Source.Accept(v) + if !ok { + return n, false + } + n.Source = node.(ResultSetNode) + return v.Leave(n) +} + +// SetResultFields implements ResultSetNode interface. func (n *TableSource) SetResultFields(rfs []*ResultField) { n.Source.SetResultFields(rfs) } -// GetResultFields implements ResultSet interface. +// GetResultFields implements ResultSetNode interface. func (n *TableSource) GetResultFields() []*ResultField { return n.Source.GetResultFields() } @@ -189,11 +199,11 @@ type WildCardField struct { // Accept implements Node Accept interface. func (n *WildCardField) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*WildCardField) + n = newNode.(*WildCardField) return v.Leave(n) } @@ -209,17 +219,17 @@ type SelectField struct { WildCard *WildCardField // If Expr is not nil, WildCard will be nil. Expr ExprNode - // AsName name for Expr. + // Alias name for Expr. AsName model.CIStr } // Accept implements Node Accept interface. func (n *SelectField) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*SelectField) + n = newNode.(*SelectField) if n.Expr != nil { node, ok := n.Expr.Accept(v) if !ok { @@ -239,11 +249,11 @@ type FieldList struct { // Accept implements Node Accept interface. func (n *FieldList) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*FieldList) + n = newNode.(*FieldList) for i, val := range n.Fields { node, ok := val.Accept(v) if !ok { @@ -263,11 +273,11 @@ type TableRefsClause struct { // Accept implements Node Accept interface. func (n *TableRefsClause) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*TableRefsClause) + n = newNode.(*TableRefsClause) node, ok := n.TableRefs.Accept(v) if !ok { return n, false @@ -286,11 +296,11 @@ type ByItem struct { // Accept implements Node Accept interface. func (n *ByItem) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*ByItem) + n = newNode.(*ByItem) node, ok := n.Expr.Accept(v) if !ok { return n, false @@ -307,11 +317,11 @@ type GroupByClause struct { // Accept implements Node Accept interface. func (n *GroupByClause) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*GroupByClause) + n = newNode.(*GroupByClause) for i, val := range n.Items { node, ok := val.Accept(v) if !ok { @@ -330,11 +340,11 @@ type HavingClause struct { // Accept implements Node Accept interface. func (n *HavingClause) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*HavingClause) + n = newNode.(*HavingClause) node, ok := n.Expr.Accept(v) if !ok { return n, false @@ -352,11 +362,11 @@ type OrderByClause struct { // Accept implements Node Accept interface. func (n *OrderByClause) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*OrderByClause) + n = newNode.(*OrderByClause) for i, val := range n.Items { node, ok := val.Accept(v) if !ok { @@ -395,12 +405,12 @@ type SelectStmt struct { // Accept implements Node Accept interface. func (n *SelectStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*SelectStmt) + n = newNode.(*SelectStmt) if n.From != nil { node, ok := n.From.Accept(v) if !ok { @@ -456,6 +466,7 @@ func (n *SelectStmt) Accept(v Visitor) (Node, bool) { } n.Limit = node.(*Limit) } + return v.Leave(n) } @@ -469,11 +480,11 @@ type UnionClause struct { // Accept implements Node Accept interface. func (n *UnionClause) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*UnionClause) + n = newNode.(*UnionClause) node, ok := n.Select.Accept(v) if !ok { return n, false @@ -496,11 +507,11 @@ type UnionStmt struct { // Accept implements Node Accept interface. func (n *UnionStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*UnionStmt) + n = newNode.(*UnionStmt) for i, val := range n.Selects { node, ok := val.Accept(v) if !ok { @@ -536,11 +547,11 @@ type Assignment struct { // Accept implements Node Accept interface. func (n *Assignment) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*Assignment) + n = newNode.(*Assignment) node, ok := n.Column.Accept(v) if !ok { return n, false @@ -580,12 +591,12 @@ type InsertStmt struct { // Accept implements Node Accept interface. func (n *InsertStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*InsertStmt) + n = newNode.(*InsertStmt) if n.Select != nil { node, ok := n.Select.Accept(v) if !ok { @@ -593,6 +604,7 @@ func (n *InsertStmt) Accept(v Visitor) (Node, bool) { } n.Select = node.(ResultSetNode) } + node, ok := n.Table.Accept(v) if !ok { return n, false @@ -653,12 +665,12 @@ type DeleteStmt struct { // Accept implements Node Accept interface. func (n *DeleteStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*DeleteStmt) + n = newNode.(*DeleteStmt) node, ok := n.TableRefs.Accept(v) if !ok { return n, false @@ -714,11 +726,11 @@ type UpdateStmt struct { // Accept implements Node Accept interface. func (n *UpdateStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*UpdateStmt) + n = newNode.(*UpdateStmt) node, ok := n.TableRefs.Accept(v) if !ok { return n, false @@ -765,10 +777,90 @@ type Limit struct { // Accept implements Node Accept interface. func (n *Limit) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) + } + n = newNode.(*Limit) + return v.Leave(n) +} + +// ShowStmtType is the type for SHOW statement. +type ShowStmtType int + +// Show statement types. +const ( + ShowNone = iota + ShowEngines + ShowDatabases + ShowTables + ShowTableStatus + ShowColumns + ShowWarnings + ShowCharset + ShowVariables + ShowStatus + ShowCollation + ShowCreateTable + ShowGrants + ShowTriggers + ShowProcedureStatus + ShowIndex +) + +// ShowStmt is a statement to provide information about databases, tables, columns and so on. +// See: https://dev.mysql.com/doc/refman/5.7/en/show.html +type ShowStmt struct { + dmlNode + + Tp ShowStmtType // Databases/Tables/Columns/.... + DBName string + Table *TableName // Used for showing columns. + Column *ColumnName // Used for `desc table column`. + Flag int // Some flag parsed from sql, such as FULL. + Full bool + User string // Used for show grants. + + // Used by show variables + GlobalScope bool + Pattern *PatternLikeExpr + Where ExprNode +} + +// Accept implements Node Accept interface. +func (n *ShowStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*ShowStmt) + if n.Table != nil { + node, ok := n.Table.Accept(v) + if !ok { + return n, false + } + n.Table = node.(*TableName) + } + if n.Column != nil { + node, ok := n.Column.Accept(v) + if !ok { + return n, false + } + n.Column = node.(*ColumnName) + } + if n.Pattern != nil { + node, ok := n.Pattern.Accept(v) + if !ok { + return n, false + } + n.Pattern = node.(*PatternLikeExpr) + } + if n.Where != nil { + node, ok := n.Where.Accept(v) + if !ok { + return n, false + } + n.Where = node.(ExprNode) } - n = newNod.(*Limit) return v.Leave(n) } diff --git a/ast/expressions.go b/ast/expressions.go index adc0ac74e9..7d55241338 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -23,29 +23,30 @@ import ( ) var ( - _ ExprNode = &ValueExpr{} _ ExprNode = &BetweenExpr{} _ ExprNode = &BinaryOperationExpr{} - _ Node = &WhenClause{} _ ExprNode = &CaseExpr{} - _ ExprNode = &SubqueryExpr{} - _ ExprNode = &CompareSubqueryExpr{} - _ Node = &ColumnName{} _ ExprNode = &ColumnNameExpr{} + _ ExprNode = &CompareSubqueryExpr{} _ ExprNode = &DefaultExpr{} _ ExprNode = &ExistsSubqueryExpr{} - _ ExprNode = &PatternInExpr{} _ ExprNode = &IsNullExpr{} _ ExprNode = &IsTruthExpr{} - _ ExprNode = &PatternLikeExpr{} _ ExprNode = &ParamMarkerExpr{} _ ExprNode = &ParenthesesExpr{} - _ ExprNode = &PositionExpr{} + _ ExprNode = &PatternInExpr{} + _ ExprNode = &PatternLikeExpr{} _ ExprNode = &PatternRegexpExpr{} + _ ExprNode = &PositionExpr{} _ ExprNode = &RowExpr{} + _ ExprNode = &SubqueryExpr{} _ ExprNode = &UnaryOperationExpr{} + _ ExprNode = &ValueExpr{} _ ExprNode = &ValuesExpr{} _ ExprNode = &VariableExpr{} + + _ Node = &ColumnName{} + _ Node = &WhenClause{} ) // ValueExpr is the simple value expression. @@ -69,11 +70,11 @@ func NewValueExpr(value interface{}) *ValueExpr { // Accept implements Node interface. func (n *ValueExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*ValueExpr) + n = newNode.(*ValueExpr) return v.Leave(n) } @@ -92,12 +93,12 @@ type BetweenExpr struct { // Accept implements Node interface. func (n *BetweenExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*BetweenExpr) + n = newNode.(*BetweenExpr) node, ok := n.Expr.Accept(v) if !ok { return n, false @@ -119,7 +120,7 @@ func (n *BetweenExpr) Accept(v Visitor) (Node, bool) { return v.Leave(n) } -// BinaryOperationExpr is for binary operation like 1 + 1, 1 - 1, etc. +// BinaryOperationExpr is for binary operation like `1 + 1`, `1 - 1`, etc. type BinaryOperationExpr struct { exprNode // Op is the operator code for BinaryOperation. @@ -132,12 +133,12 @@ type BinaryOperationExpr struct { // Accept implements Node interface. func (n *BinaryOperationExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*BinaryOperationExpr) + n = newNode.(*BinaryOperationExpr) node, ok := n.L.Accept(v) if !ok { return n, false @@ -164,11 +165,12 @@ type WhenClause struct { // Accept implements Node Accept interface. func (n *WhenClause) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*WhenClause) + + n = newNode.(*WhenClause) node, ok := n.Expr.Accept(v) if !ok { return n, false @@ -196,11 +198,12 @@ type CaseExpr struct { // Accept implements Node Accept interface. func (n *CaseExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*CaseExpr) + + n = newNode.(*CaseExpr) if n.Value != nil { node, ok := n.Value.Accept(v) if !ok { @@ -234,11 +237,11 @@ type SubqueryExpr struct { // Accept implements Node Accept interface. func (n *SubqueryExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*SubqueryExpr) + n = newNode.(*SubqueryExpr) node, ok := n.Query.Accept(v) if !ok { return n, false @@ -247,12 +250,12 @@ func (n *SubqueryExpr) Accept(v Visitor) (Node, bool) { return v.Leave(n) } -// SetResultFields implements ResultSet interface. +// SetResultFields implements ResultSetNode interface. func (n *SubqueryExpr) SetResultFields(rfs []*ResultField) { n.Query.SetResultFields(rfs) } -// GetResultFields implements ResultSet interface. +// GetResultFields implements ResultSetNode interface. func (n *SubqueryExpr) GetResultFields() []*ResultField { return n.Query.GetResultFields() } @@ -275,11 +278,11 @@ type CompareSubqueryExpr struct { // Accept implements Node Accept interface. func (n *CompareSubqueryExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*CompareSubqueryExpr) + n = newNode.(*CompareSubqueryExpr) node, ok := n.L.Accept(v) if !ok { return n, false @@ -303,11 +306,11 @@ type ColumnName struct { // Accept implements Node Accept interface. func (n *ColumnName) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*ColumnName) + n = newNode.(*ColumnName) return v.Leave(n) } @@ -325,11 +328,11 @@ type ColumnNameExpr struct { // Accept implements Node Accept interface. func (n *ColumnNameExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*ColumnNameExpr) + n = newNode.(*ColumnNameExpr) node, ok := n.Name.Accept(v) if !ok { return n, false @@ -347,11 +350,11 @@ type DefaultExpr struct { // Accept implements Node Accept interface. func (n *DefaultExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*DefaultExpr) + n = newNode.(*DefaultExpr) if n.Name != nil { node, ok := n.Name.Accept(v) if !ok { @@ -372,11 +375,11 @@ type ExistsSubqueryExpr struct { // Accept implements Node Accept interface. func (n *ExistsSubqueryExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*ExistsSubqueryExpr) + n = newNode.(*ExistsSubqueryExpr) node, ok := n.Sel.Accept(v) if !ok { return n, false @@ -400,11 +403,11 @@ type PatternInExpr struct { // Accept implements Node Accept interface. func (n *PatternInExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*PatternInExpr) + n = newNode.(*PatternInExpr) node, ok := n.Expr.Accept(v) if !ok { return n, false @@ -438,11 +441,11 @@ type IsNullExpr struct { // Accept implements Node Accept interface. func (n *IsNullExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*IsNullExpr) + n = newNode.(*IsNullExpr) node, ok := n.Expr.Accept(v) if !ok { return n, false @@ -464,11 +467,11 @@ type IsTruthExpr struct { // Accept implements Node Accept interface. func (n *IsTruthExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*IsTruthExpr) + n = newNode.(*IsTruthExpr) node, ok := n.Expr.Accept(v) if !ok { return n, false @@ -495,11 +498,11 @@ type PatternLikeExpr struct { // Accept implements Node Accept interface. func (n *PatternLikeExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*PatternLikeExpr) + n = newNode.(*PatternLikeExpr) if n.Expr != nil { node, ok := n.Expr.Accept(v) if !ok { @@ -526,11 +529,11 @@ type ParamMarkerExpr struct { // Accept implements Node Accept interface. func (n *ParamMarkerExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*ParamMarkerExpr) + n = newNode.(*ParamMarkerExpr) return v.Leave(n) } @@ -543,11 +546,11 @@ type ParenthesesExpr struct { // Accept implements Node Accept interface. func (n *ParenthesesExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*ParenthesesExpr) + n = newNode.(*ParenthesesExpr) if n.Expr != nil { node, ok := n.Expr.Accept(v) if !ok { @@ -571,11 +574,11 @@ type PositionExpr struct { // Accept implements Node Accept interface. func (n *PositionExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*PositionExpr) + n = newNode.(*PositionExpr) return v.Leave(n) } @@ -597,11 +600,11 @@ type PatternRegexpExpr struct { // Accept implements Node Accept interface. func (n *PatternRegexpExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*PatternRegexpExpr) + n = newNode.(*PatternRegexpExpr) node, ok := n.Expr.Accept(v) if !ok { return n, false @@ -625,11 +628,11 @@ type RowExpr struct { // Accept implements Node Accept interface. func (n *RowExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*RowExpr) + n = newNode.(*RowExpr) for i, val := range n.Values { node, ok := val.Accept(v) if !ok { @@ -651,11 +654,11 @@ type UnaryOperationExpr struct { // Accept implements Node Accept interface. func (n *UnaryOperationExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*UnaryOperationExpr) + n = newNode.(*UnaryOperationExpr) node, ok := n.V.Accept(v) if !ok { return n, false @@ -673,11 +676,11 @@ type ValuesExpr struct { // Accept implements Node Accept interface. func (n *ValuesExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*ValuesExpr) + n = newNode.(*ValuesExpr) node, ok := n.Column.Accept(v) if !ok { return n, false @@ -693,16 +696,16 @@ type VariableExpr struct { Name string // IsGlobal indicates whether this variable is global. IsGlobal bool - // IsSystem indicates whether this variable is a global variable in current session. + // IsSystem indicates whether this variable is a system variable in current session. IsSystem bool } // Accept implements Node Accept interface. func (n *VariableExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*VariableExpr) + n = newNode.(*VariableExpr) return v.Leave(n) } diff --git a/ast/functions.go b/ast/functions.go index 0a96398382..ff253ca0ed 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -19,15 +19,16 @@ import ( ) var ( - _ FuncNode = &FuncCallExpr{} - _ FuncNode = &FuncExtractExpr{} - _ FuncNode = &FuncConvertExpr{} - _ FuncNode = &FuncCastExpr{} - _ FuncNode = &FuncSubstringExpr{} - _ FuncNode = &FuncLocateExpr{} - _ FuncNode = &FuncTrimExpr{} - _ FuncNode = &FuncDateArithExpr{} _ FuncNode = &AggregateFuncExpr{} + _ FuncNode = &FuncCallExpr{} + _ FuncNode = &FuncCastExpr{} + _ FuncNode = &FuncConvertExpr{} + _ FuncNode = &FuncDateArithExpr{} + _ FuncNode = &FuncExtractExpr{} + _ FuncNode = &FuncLocateExpr{} + _ FuncNode = &FuncSubstringExpr{} + _ FuncNode = &FuncSubstringIndexExpr{} + _ FuncNode = &FuncTrimExpr{} ) // UnquoteString is not quoted when printed. @@ -44,11 +45,11 @@ type FuncCallExpr struct { // Accept implements Node interface. func (n *FuncCallExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*FuncCallExpr) + n = newNode.(*FuncCallExpr) for i, val := range n.Args { node, ok := val.Accept(v) if !ok { @@ -70,11 +71,11 @@ type FuncExtractExpr struct { // Accept implements Node Accept interface. func (n *FuncExtractExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*FuncExtractExpr) + n = newNode.(*FuncExtractExpr) node, ok := n.Date.Accept(v) if !ok { return n, false @@ -95,11 +96,11 @@ type FuncConvertExpr struct { // Accept implements Node Accept interface. func (n *FuncConvertExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*FuncConvertExpr) + n = newNode.(*FuncConvertExpr) node, ok := n.Expr.Accept(v) if !ok { return n, false @@ -132,11 +133,11 @@ type FuncCastExpr struct { // Accept implements Node Accept interface. func (n *FuncCastExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*FuncCastExpr) + n = newNode.(*FuncCastExpr) node, ok := n.Expr.Accept(v) if !ok { return n, false @@ -157,11 +158,11 @@ type FuncSubstringExpr struct { // Accept implements Node Accept interface. func (n *FuncSubstringExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*FuncSubstringExpr) + n = newNode.(*FuncSubstringExpr) node, ok := n.StrExpr.Accept(v) if !ok { return n, false @@ -194,11 +195,11 @@ type FuncSubstringIndexExpr struct { // Accept implements Node Accept interface. func (n *FuncSubstringIndexExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*FuncSubstringIndexExpr) + n = newNode.(*FuncSubstringIndexExpr) node, ok := n.StrExpr.Accept(v) if !ok { return n, false @@ -229,11 +230,11 @@ type FuncLocateExpr struct { // Accept implements Node Accept interface. func (n *FuncLocateExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*FuncLocateExpr) + n = newNode.(*FuncLocateExpr) node, ok := n.Str.Accept(v) if !ok { return n, false @@ -278,11 +279,11 @@ type FuncTrimExpr struct { // Accept implements Node Accept interface. func (n *FuncTrimExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*FuncTrimExpr) + n = newNode.(*FuncTrimExpr) node, ok := n.Str.Accept(v) if !ok { return n, false @@ -330,11 +331,11 @@ type FuncDateArithExpr struct { // Accept implements Node Accept interface. func (n *FuncDateArithExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*FuncDateArithExpr) + n = newNode.(*FuncDateArithExpr) if n.Date != nil { node, ok := n.Date.Accept(v) if !ok { @@ -367,11 +368,11 @@ type AggregateFuncExpr struct { // Accept implements Node Accept interface. func (n *AggregateFuncExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*AggregateFuncExpr) + n = newNode.(*AggregateFuncExpr) for i, val := range n.Args { node, ok := val.Accept(v) if !ok { diff --git a/ast/misc.go b/ast/misc.go index 4480df0c43..5b847333c7 100644 --- a/ast/misc.go +++ b/ast/misc.go @@ -16,27 +16,27 @@ package ast import "github.com/pingcap/tidb/mysql" var ( - _ StmtNode = &ExplainStmt{} - _ StmtNode = &PrepareStmt{} - _ StmtNode = &DeallocateStmt{} - _ StmtNode = &ExecuteStmt{} - _ StmtNode = &ShowStmt{} _ StmtNode = &BeginStmt{} _ StmtNode = &CommitStmt{} + _ StmtNode = &CreateUserStmt{} + _ StmtNode = &DeallocateStmt{} + _ StmtNode = &DoStmt{} + _ StmtNode = &ExecuteStmt{} + _ StmtNode = &ExplainStmt{} + _ StmtNode = &GrantStmt{} + _ StmtNode = &PrepareStmt{} _ StmtNode = &RollbackStmt{} - _ StmtNode = &UseStmt{} - _ StmtNode = &SetStmt{} _ StmtNode = &SetCharsetStmt{} _ StmtNode = &SetPwdStmt{} - _ StmtNode = &CreateUserStmt{} - _ StmtNode = &DoStmt{} - _ StmtNode = &GrantStmt{} + _ StmtNode = &SetStmt{} + _ StmtNode = &UseStmt{} + _ Node = &PrivElem{} _ Node = &VariableAssignment{} ) // FloatOpt is used for parsing floating-point type option from SQL. -// TODO: add reference doc. +// See: http://dev.mysql.com/doc/refman/5.7/en/floating-point-types.html type FloatOpt struct { Flen int Decimal int @@ -62,11 +62,11 @@ type ExplainStmt struct { // Accept implements Node Accept interface. func (n *ExplainStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*ExplainStmt) + n = newNode.(*ExplainStmt) node, ok := n.Stmt.Accept(v) if !ok { return n, false @@ -88,11 +88,11 @@ type PrepareStmt struct { // Accept implements Node Accept interface. func (n *PrepareStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*PrepareStmt) + n = newNode.(*PrepareStmt) if n.SQLVar != nil { node, ok := n.SQLVar.Accept(v) if !ok { @@ -113,11 +113,11 @@ type DeallocateStmt struct { // Accept implements Node Accept interface. func (n *DeallocateStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*DeallocateStmt) + n = newNode.(*DeallocateStmt) return v.Leave(n) } @@ -132,11 +132,11 @@ type ExecuteStmt struct { // Accept implements Node Accept interface. func (n *ExecuteStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*ExecuteStmt) + n = newNode.(*ExecuteStmt) for i, val := range n.UsingVars { node, ok := val.Accept(v) if !ok { @@ -147,86 +147,6 @@ func (n *ExecuteStmt) Accept(v Visitor) (Node, bool) { return v.Leave(n) } -// ShowStmtType is the type for SHOW statement. -type ShowStmtType int - -// Show statement types. -const ( - ShowNone = iota - ShowEngines - ShowDatabases - ShowTables - ShowTableStatus - ShowColumns - ShowWarnings - ShowCharset - ShowVariables - ShowStatus - ShowCollation - ShowCreateTable - ShowGrants - ShowTriggers - ShowProcedureStatus - ShowIndex -) - -// ShowStmt is a statement to provide information about databases, tables, columns and so on. -// See: https://dev.mysql.com/doc/refman/5.7/en/show.html -type ShowStmt struct { - dmlNode - - Tp ShowStmtType // Databases/Tables/Columns/.... - DBName string - Table *TableName // Used for showing columns. - Column *ColumnName // Used for `desc table column`. - Flag int // Some flag parsed from sql, such as FULL. - Full bool - User string // Used for show grants. - - // Used by show variables - GlobalScope bool - Pattern *PatternLikeExpr - Where ExprNode -} - -// Accept implements Node Accept interface. -func (n *ShowStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNod) - } - n = newNod.(*ShowStmt) - if n.Table != nil { - node, ok := n.Table.Accept(v) - if !ok { - return n, false - } - n.Table = node.(*TableName) - } - if n.Column != nil { - node, ok := n.Column.Accept(v) - if !ok { - return n, false - } - n.Column = node.(*ColumnName) - } - if n.Pattern != nil { - node, ok := n.Pattern.Accept(v) - if !ok { - return n, false - } - n.Pattern = node.(*PatternLikeExpr) - } - if n.Where != nil { - node, ok := n.Where.Accept(v) - if !ok { - return n, false - } - n.Where = node.(ExprNode) - } - return v.Leave(n) -} - // BeginStmt is a statement to start a new transaction. // See: https://dev.mysql.com/doc/refman/5.7/en/commit.html type BeginStmt struct { @@ -235,11 +155,11 @@ type BeginStmt struct { // Accept implements Node Accept interface. func (n *BeginStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*BeginStmt) + n = newNode.(*BeginStmt) return v.Leave(n) } @@ -251,11 +171,11 @@ type CommitStmt struct { // Accept implements Node Accept interface. func (n *CommitStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*CommitStmt) + n = newNode.(*CommitStmt) return v.Leave(n) } @@ -267,11 +187,11 @@ type RollbackStmt struct { // Accept implements Node Accept interface. func (n *RollbackStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*RollbackStmt) + n = newNode.(*RollbackStmt) return v.Leave(n) } @@ -285,11 +205,11 @@ type UseStmt struct { // Accept implements Node Accept interface. func (n *UseStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*UseStmt) + n = newNode.(*UseStmt) return v.Leave(n) } @@ -304,11 +224,11 @@ type VariableAssignment struct { // Accept implements Node interface. func (n *VariableAssignment) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*VariableAssignment) + n = newNode.(*VariableAssignment) node, ok := n.Value.Accept(v) if !ok { return n, false @@ -326,11 +246,11 @@ type SetStmt struct { // Accept implements Node Accept interface. func (n *SetStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*SetStmt) + n = newNode.(*SetStmt) for i, val := range n.Variables { node, ok := val.Accept(v) if !ok { @@ -352,11 +272,11 @@ type SetCharsetStmt struct { // Accept implements Node Accept interface. func (n *SetCharsetStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*SetCharsetStmt) + n = newNode.(*SetCharsetStmt) return v.Leave(n) } @@ -371,11 +291,11 @@ type SetPwdStmt struct { // Accept implements Node Accept interface. func (n *SetPwdStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*SetPwdStmt) + n = newNode.(*SetPwdStmt) return v.Leave(n) } @@ -396,11 +316,11 @@ type CreateUserStmt struct { // Accept implements Node Accept interface. func (n *CreateUserStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*CreateUserStmt) + n = newNode.(*CreateUserStmt) return v.Leave(n) } @@ -413,11 +333,11 @@ type DoStmt struct { // Accept implements Node Accept interface. func (n *DoStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*DoStmt) + n = newNode.(*DoStmt) for i, val := range n.Exprs { node, ok := val.Accept(v) if !ok { @@ -467,17 +387,18 @@ func (n *AdminStmt) Accept(v Visitor) (Node, bool) { // PrivElem is the privilege type and optional column list. type PrivElem struct { node + Priv mysql.PrivilegeType Cols []*ColumnName } // Accept implements Node Accept interface. func (n *PrivElem) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*PrivElem) + n = newNode.(*PrivElem) for i, val := range n.Cols { node, ok := val.Accept(v) if !ok { @@ -531,11 +452,11 @@ type GrantStmt struct { // Accept implements Node Accept interface. func (n *GrantStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(n) + newNode, skipChildren := v.Enter(n) if skipChildren { - return v.Leave(newNod) + return v.Leave(newNode) } - n = newNod.(*GrantStmt) + n = newNode.(*GrantStmt) for i, val := range n.Privs { node, ok := val.Accept(v) if !ok { diff --git a/ddl/ddl.go b/ddl/ddl.go index 4f0523f382..f4997c9ac8 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -434,6 +434,8 @@ func (d *ddl) buildTableInfo(tableName model.CIStr, cols []*column.Col, constrai switch col.Tp { case mysql.TypeLong, mysql.TypeLonglong: tbInfo.PKIsHandle = true + // Avoid creating index for PK handle column. + continue } } } diff --git a/driver.go b/driver.go index e4411021cd..54b820fa04 100644 --- a/driver.go +++ b/driver.go @@ -153,7 +153,7 @@ func parseDriverDSN(dsn string) (storePath, dbName string, err error) { // Examples: // goleveldb://relative/path/test // boltdb:///absolute/path/test -// hbase://zk1,zk2,zk3/hbasetbl/test?tso=127.0.0.1:1234 +// hbase://zk1,zk2,zk3/hbasetbl/test?tso=zk // // Open may return a cached connection (one previously closed), but doing so is // unnecessary; the sql package maintains a pool of idle connections for diff --git a/executor/adapter.go b/executor/adapter.go index 90fe9e4cd2..d5c27d40ac 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -27,7 +27,7 @@ import ( "github.com/pingcap/tidb/util/types" ) -// adapter wraps a executor, implements rset.Recordset interface +// adapter wraps an executor, implements rset.Recordset interface type recordsetAdapter struct { fields []*field.ResultField executor Executor @@ -42,12 +42,12 @@ func (a *recordsetAdapter) Fields() ([]*field.ResultField, error) { } func (a *recordsetAdapter) FirstRow() ([]interface{}, error) { - ro, err := a.Next() + row, err := a.Next() a.Close() - if ro == nil || err != nil { + if err != nil || row == nil { return nil, errors.Trace(err) } - return ro.Data, nil + return row.Data, nil } func (a *recordsetAdapter) Rows(limit, offset int) ([][]interface{}, error) { @@ -56,7 +56,7 @@ func (a *recordsetAdapter) Rows(limit, offset int) ([][]interface{}, error) { // Move to offset. for offset > 0 { row, err := a.Next() - if row == nil || err != nil { + if err != nil || row == nil { return nil, errors.Trace(err) } offset-- @@ -80,7 +80,7 @@ func (a *recordsetAdapter) Rows(limit, offset int) ([][]interface{}, error) { func (a *recordsetAdapter) Next() (*oplan.Row, error) { row, err := a.executor.Next() - if row == nil || err != nil { + if err != nil || row == nil { return nil, errors.Trace(err) } oRow := &oplan.Row{ @@ -159,12 +159,9 @@ func (a *statementAdapter) Exec(ctx context.Context) (rset.Recordset, error) { defer e.Close() for { row, err := e.Next() - if err != nil { + if err != nil || row == nil { return nil, errors.Trace(err) } - if row == nil { - return nil, nil - } } } return &recordsetAdapter{ diff --git a/executor/builder.go b/executor/builder.go index 9003ca5364..511fc5a247 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -14,6 +14,8 @@ package executor import ( + "math" + "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/column" "github.com/pingcap/tidb/context" @@ -76,9 +78,11 @@ func (b *executorBuilder) build(p plan.Plan) Executor { func (b *executorBuilder) buildTableScan(v *plan.TableScan) Executor { table, _ := b.is.TableByID(v.Table.ID) return &TableScanExec{ - t: table, - fields: v.Fields(), - ctx: b.ctx, + t: table, + fields: v.Fields(), + ctx: b.ctx, + ranges: v.Ranges, + seekHandle: math.MinInt64, } } @@ -99,8 +103,7 @@ func (b *executorBuilder) buildCheckTable(v *plan.CheckTable) Executor { func (b *executorBuilder) buildIndexScan(v *plan.IndexScan) Executor { tbl, _ := b.is.TableByID(v.Table.ID) var idx *column.IndexedCol - indices := tbl.Indices() - for _, val := range indices { + for _, val := range tbl.Indices() { if val.IndexInfo.Name.L == v.Index.Name.L { idx = val break @@ -128,14 +131,14 @@ func (b *executorBuilder) buildIndexScan(v *plan.IndexScan) Executor { } func (b *executorBuilder) buildIndexRange(scan *IndexScanExec, v *plan.IndexRange) *IndexRangeExec { - rang := &IndexRangeExec{ + ran := &IndexRangeExec{ scan: scan, lowVals: v.LowVal, lowExclude: v.LowExclude, highVals: v.HighVal, highExclude: v.HighExclude, } - return rang + return ran } func (b *executorBuilder) joinConditions(conditions []ast.ExprNode) ast.ExprNode { diff --git a/executor/compiler.go b/executor/compiler.go index 4efedcc109..eaec04f20d 100644 --- a/executor/compiler.go +++ b/executor/compiler.go @@ -46,11 +46,11 @@ func (c *Compiler) Compile(ctx context.Context, node ast.StmtNode) (stmt.Stateme if err != nil { return nil, errors.Trace(err) } - a := &statementAdapter{ + sa := &statementAdapter{ is: is, plan: p, } - return a, nil + return sa, nil } c.converter = &converter.Converter{} s, err := c.converter.Convert(node) diff --git a/executor/executor.go b/executor/executor.go index a2c481950a..8ffcc59073 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -31,18 +31,18 @@ import ( "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/terror" - "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/types" ) var ( - _ Executor = &TableScanExec{} - _ Executor = &IndexScanExec{} - _ Executor = &IndexRangeExec{} - _ Executor = &SelectFieldsExec{} _ Executor = &FilterExec{} + _ Executor = &IndexRangeExec{} + _ Executor = &IndexScanExec{} _ Executor = &LimitExec{} + _ Executor = &SelectFieldsExec{} + _ Executor = &SelectLockExec{} _ Executor = &SortExec{} + _ Executor = &TableScanExec{} ) // Error instances. @@ -190,10 +190,13 @@ func (e *CheckTable) Close() error { // TableScanExec represents a table scan executor. type TableScanExec struct { - t table.Table - fields []*ast.ResultField - iter kv.Iterator - ctx context.Context + t table.Table + fields []*ast.ResultField + iter kv.Iterator + ctx context.Context + ranges []plan.TableRange // Disjoint close handle ranges. + seekHandle int64 // The handle to seek, should be initialized to math.MinInt64. + cursor int // The range cursor, used to locate to current range. } // Fields implements Executor Fields interface. @@ -203,36 +206,89 @@ func (e *TableScanExec) Fields() []*ast.ResultField { // Next implements Execution Next interface. func (e *TableScanExec) Next() (*Row, error) { - if e.iter == nil { - txn, err := e.ctx.GetTxn(false) + for { + if e.cursor >= len(e.ranges) { + return nil, nil + } + ran := e.ranges[e.cursor] + if e.seekHandle < ran.LowVal { + e.seekHandle = ran.LowVal + } + if e.seekHandle > ran.HighVal { + e.cursor++ + continue + } + rowKey, err := e.seek() + if err != nil || rowKey == nil { + return nil, errors.Trace(err) + } + handle, err := tables.DecodeRecordKeyHandle(rowKey) if err != nil { return nil, errors.Trace(err) } - e.iter, err = txn.Seek(e.t.FirstKey()) + if handle > ran.HighVal { + // The handle is out of the current range, but may be in following ranges. + // We seek to the range that may contains the handle, so we + // don't need to seek key again. + inRange := e.seekRange(handle) + if !inRange { + // The handle may be less than the current range low value, can not + // return directly. + continue + } + } + row, err := e.getRow(handle, rowKey) if err != nil { return nil, errors.Trace(err) } + e.seekHandle = handle + 1 + return row, nil } - if !e.iter.Valid() || !e.iter.Key().HasPrefix(e.t.RecordPrefix()) { - return nil, nil - } - // TODO: check if lock valid - // the record layout in storage (key -> value): - // r1 -> lock-version - // r1_col1 -> r1 col1 value - // r1_col2 -> r1 col2 value - // r2 -> lock-version - // r2_col1 -> r2 col1 value - // r2_col2 -> r2 col2 value - // ... - rowKey := e.iter.Key() - handle, err := tables.DecodeRecordKeyHandle(rowKey) +} + +func (e *TableScanExec) seek() (kv.Key, error) { + seekKey := tables.EncodeRecordKey(e.t.TableID(), e.seekHandle, 0) + txn, err := e.ctx.GetTxn(false) if err != nil { return nil, errors.Trace(err) } + if e.iter != nil { + e.iter.Close() + } + e.iter, err = txn.Seek(seekKey) + if err != nil { + return nil, errors.Trace(err) + } + if !e.iter.Valid() || !e.iter.Key().HasPrefix(e.t.RecordPrefix()) { + // No more records in the table, skip to the end. + e.cursor = len(e.ranges) + return nil, nil + } + return e.iter.Key(), nil +} - // TODO: we could just fetch mentioned columns' values +// seekRange increments the range cursor to the range +// with high value greater or equal to handle. +func (e *TableScanExec) seekRange(handle int64) (inRange bool) { + for { + e.cursor++ + if e.cursor >= len(e.ranges) { + return false + } + ran := e.ranges[e.cursor] + if handle < ran.LowVal { + return false + } + if handle > ran.HighVal { + continue + } + return true + } +} + +func (e *TableScanExec) getRow(handle int64, rowKey kv.Key) (*Row, error) { row := &Row{} + var err error row.Data, err = e.t.Row(e.ctx, handle) if err != nil { return nil, errors.Trace(err) @@ -248,12 +304,6 @@ func (e *TableScanExec) Next() (*Row, error) { Key: string(rowKey), } row.RowKeys = append(row.RowKeys, rke) - - rk := e.t.RecordKey(handle, nil) - err = kv.NextUntil(e.iter, util.RowKeyPrefixFilter(rk)) - if err != nil { - return nil, errors.Trace(err) - } return row, nil } @@ -311,6 +361,7 @@ func (e *IndexRangeExec) Next() (*Row, error) { return nil, types.EOFAsNil(err) } } + for { if e.finished { return nil, nil @@ -464,8 +515,7 @@ func (e *IndexScanExec) Next() (*Row, error) { // Close implements Executor Close interface. func (e *IndexScanExec) Close() error { for e.rangeIdx < len(e.Ranges) { - ran := e.Ranges[e.rangeIdx] - ran.Close() + e.Ranges[e.rangeIdx].Close() e.rangeIdx++ } return nil @@ -649,7 +699,7 @@ func (e *LimitExec) Close() error { return e.Src.Close() } -// orderByRow bind a row to its order values, so it can be sorted. +// orderByRow binds a row to its order values, so it can be sorted. type orderByRow struct { key []interface{} row *Row @@ -709,9 +759,6 @@ func (e *SortExec) Less(i, j int) bool { // Next implements Executor Next interface. func (e *SortExec) Next() (*Row, error) { - if e.err != nil { - return nil, errors.Trace(e.err) - } if !e.fetched { for { srcRow, err := e.Src.Next() @@ -736,6 +783,9 @@ func (e *SortExec) Next() (*Row, error) { sort.Sort(e) e.fetched = true } + if e.err != nil { + return nil, errors.Trace(e.err) + } if e.Idx >= len(e.Rows) { return nil, nil } diff --git a/executor/executor_test.go b/executor/executor_test.go index 07b823b6cd..1884f94e99 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -117,3 +117,65 @@ func (s *testSuite) TestPrepared(c *C) { exec.Next() exec.Close() } + +func (s *testSuite) TestTablePKisHandleScan(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a int PRIMARY KEY AUTO_INCREMENT)") + tk.MustExec("insert t values (),()") + tk.MustExec("insert t values (-100),(0)") + + cases := []struct { + sql string + result [][]interface{} + }{ + { + "select * from t", + testkit.Rows("-100", "0", "1", "2"), + }, + { + "select * from t where a = 1", + testkit.Rows("1"), + }, + { + "select * from t where a != 1", + testkit.Rows("-100", "0", "2"), + }, + { + "select * from t where a >= '1.1'", + testkit.Rows("2"), + }, + { + "select * from t where a < '1.1'", + testkit.Rows("-100", "0", "1"), + }, + { + "select * from t where a > '-100.1' and a < 2", + testkit.Rows("-100", "0", "1"), + }, + { + "select * from t where a is null", + testkit.Rows(), + }, { + "select * from t where a is true", + testkit.Rows("-100", "1", "2"), + }, { + "select * from t where a is false", + testkit.Rows("0"), + }, + { + "select * from t where a in (0, 2)", + testkit.Rows("0", "2"), + }, + { + "select * from t where a between 0 and 1", + testkit.Rows("0", "1"), + }, + } + + for _, ca := range cases { + result := tk.MustQuery(ca.sql) + result.Check(ca.result) + } +} diff --git a/executor/prepared.go b/executor/prepared.go index e3f5db11fb..b52e48c474 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -31,6 +31,12 @@ import ( "github.com/pingcap/tidb/stmt" ) +var ( + _ Executor = &DeallocateExec{} + _ Executor = &ExecuteExec{} + _ Executor = &PrepareExec{} +) + type paramMarkerSorter struct { markers []*ast.ParamMarkerExpr } @@ -277,9 +283,9 @@ func CompileExecutePreparedStmt(ctx context.Context, ID uint32, args ...interfac for i, val := range args { execPlan.UsingVars[i] = ast.NewValueExpr(val) } - a := &statementAdapter{ + sa := &statementAdapter{ is: sessionctx.GetDomain(ctx).InfoSchema(), plan: execPlan, } - return a + return sa } diff --git a/kv/kv.go b/kv/kv.go index f0505f4a06..dcafc8dbc9 100644 --- a/kv/kv.go +++ b/kv/kv.go @@ -19,6 +19,9 @@ const ( // transaction's commit. // This option is an optimization for frequent checks during a transaction, e.g. batch inserts. PresumeKeyNotExists Option = iota + 1 + // PresumeKeyNotExistsError is the option key for error. + // When PresumeKeyNotExists is set and condition is not match, should throw the error. + PresumeKeyNotExistsError ) // Retriever is the interface wraps the basic Get and Seek methods. diff --git a/kv/txn.go b/kv/txn.go index 197a6fc5c5..18515380bf 100644 --- a/kv/txn.go +++ b/kv/txn.go @@ -38,6 +38,7 @@ func RunInNewTxn(store Storage, retryable bool, f func(txn Transaction) error) e continue } if err != nil { + txn.Rollback() return errors.Trace(err) } diff --git a/kv/union_store.go b/kv/union_store.go index 80615ca58d..b5a093e088 100644 --- a/kv/union_store.go +++ b/kv/union_store.go @@ -50,12 +50,20 @@ var ( }) ) +// conditionPair is used to store lazy check condition. +// If condition not match (value is not equal as expected one), returns err. +type conditionPair struct { + key Key + value []byte + err error +} + // UnionStore is an in-memory Store which contains a buffer for write and a // snapshot for read. type unionStore struct { *BufferStore - snapshot Snapshot // for read - lazyConditionPairs MemBuffer // for delay check + snapshot Snapshot // for read + lazyConditionPairs map[string](*conditionPair) // for delay check opts options } @@ -64,7 +72,7 @@ func NewUnionStore(snapshot Snapshot) UnionStore { return &unionStore{ BufferStore: NewBufferStore(snapshot), snapshot: snapshot, - lazyConditionPairs: &lazyMemBuffer{}, + lazyConditionPairs: make(map[string](*conditionPair)), opts: make(map[Option]interface{}), } } @@ -121,9 +129,11 @@ func (us *unionStore) Get(k Key) ([]byte, error) { v, err := us.MemBuffer.Get(k) if IsErrNotFound(err) { if _, ok := us.opts.Get(PresumeKeyNotExists); ok { - err = us.markLazyConditionPair(k, nil) - if err != nil { - return nil, errors.Trace(err) + e, ok := us.opts.Get(PresumeKeyNotExistsError) + if ok && e != nil { + us.markLazyConditionPair(k, nil, e.(error)) + } else { + us.markLazyConditionPair(k, nil, ErrKeyExists) } return nil, errors.Trace(ErrNotExist) } @@ -141,45 +151,36 @@ func (us *unionStore) Get(k Key) ([]byte, error) { } // markLazyConditionPair marks a kv pair for later check. -func (us *unionStore) markLazyConditionPair(k Key, v []byte) error { - if len(v) == 0 { - return errors.Trace(us.lazyConditionPairs.Delete(k)) +// If condition not match, should return e as error. +func (us *unionStore) markLazyConditionPair(k Key, v []byte, e error) { + us.lazyConditionPairs[string(k)] = &conditionPair{ + key: k.Clone(), + value: v, + err: e, } - return errors.Trace(us.lazyConditionPairs.Set(k, v)) } // CheckLazyConditionPairs implements the UnionStore interface. func (us *unionStore) CheckLazyConditionPairs() error { - var keys []Key - it, err := us.lazyConditionPairs.Seek(nil) - if err != nil { - return errors.Trace(err) - } - for ; it.Valid(); it.Next() { - keys = append(keys, it.Key().Clone()) - } - it.Close() - - if len(keys) == 0 { + if len(us.lazyConditionPairs) == 0 { return nil } + keys := make([]Key, 0, len(us.lazyConditionPairs)) + for _, v := range us.lazyConditionPairs { + keys = append(keys, v.key) + } values, err := us.snapshot.BatchGet(keys) if err != nil { return errors.Trace(err) } - it, err = us.lazyConditionPairs.Seek(nil) - if err != nil { - return errors.Trace(err) - } - defer it.Close() - for ; it.Valid(); it.Next() { - keyStr := string(it.Key()) - if len(it.Value()) == 0 { - if _, exist := values[keyStr]; exist { - return errors.Trace(ErrKeyExists) + + for k, v := range us.lazyConditionPairs { + if len(v.value) == 0 { + if _, exist := values[k]; exist { + return errors.Trace(v.err) } } else { - if bytes.Compare(values[keyStr], it.Value()) != 0 { + if bytes.Compare(values[k], v.value) != 0 { return errors.Trace(ErrLazyConditionPairsNotMatch) } } @@ -201,7 +202,6 @@ func (us *unionStore) DelOption(opt Option) { func (us *unionStore) Release() { us.snapshot.Release() us.BufferStore.Release() - us.lazyConditionPairs.Release() } type options map[Option]interface{} diff --git a/kv/union_store_test.go b/kv/union_store_test.go index fabb744908..9e08db414f 100644 --- a/kv/union_store_test.go +++ b/kv/union_store_test.go @@ -91,7 +91,8 @@ func (s *testUnionStoreSuite) TestLazyConditionCheck(c *C) { c.Assert(err, IsNil) c.Assert(v, BytesEquals, []byte("1")) - s.us.SetOption(PresumeKeyNotExists, 1) + s.us.SetOption(PresumeKeyNotExists, nil) + s.us.SetOption(PresumeKeyNotExistsError, ErrNotExist) _, err = s.us.Get([]byte("2")) c.Assert(terror.ErrorEqual(err, ErrNotExist), IsTrue) diff --git a/optimizer/evaluator/evaluator.go b/optimizer/evaluator/evaluator.go index 2dc312850c..d8d87909d2 100644 --- a/optimizer/evaluator/evaluator.go +++ b/optimizer/evaluator/evaluator.go @@ -61,6 +61,7 @@ func EvalBool(ctx context.Context, expr ast.ExprNode) (bool, error) { if val == nil { return false, nil } + i, err := types.ToBool(val) if err != nil { return false, errors.Trace(err) @@ -75,7 +76,7 @@ func boolToInt64(v bool) int64 { return int64(0) } -// Evaluator is a ast Visitor that evaluates an expression. +// Evaluator is an ast Visitor that evaluates an expression. type Evaluator struct { ctx context.Context err error @@ -193,7 +194,7 @@ func (e *Evaluator) caseExpr(v *ast.CaseExpr) bool { for _, val := range v.WhenClauses { cmp, err := types.Compare(target, val.Expr.GetValue()) if err != nil { - e.err = err + e.err = errors.Trace(err) return false } if cmp == 0 { @@ -239,7 +240,7 @@ func (e *Evaluator) checkInList(not bool, in interface{}, list []interface{}) (i r, err := types.Compare(in, v) if err != nil { - return nil, err + return nil, errors.Trace(err) } if r == 0 { @@ -270,7 +271,7 @@ func (e *Evaluator) patternIn(n *ast.PatternInExpr) bool { } r, err := types.Compare(n.Expr.GetValue(), v.GetValue()) if err != nil { - e.err = err + e.err = errors.Trace(err) return false } if r == 0 { @@ -306,7 +307,7 @@ func (e *Evaluator) isTruth(v *ast.IsTruthExpr) bool { if !types.IsNil(val) { ival, err := types.ToBool(val) if err != nil { - e.err = err + e.err = errors.Trace(err) return false } if ival == v.True { @@ -359,7 +360,7 @@ func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool { case opcode.Not: n, err := types.ToBool(a) if err != nil { - e.err = err + e.err = errors.Trace(err) } else if n == 0 { u.SetValue(int64(1)) } else { @@ -369,7 +370,7 @@ func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool { // for bit operation, we will use int64 first, then return uint64 n, err := types.ToInt64(a) if err != nil { - e.err = err + e.err = errors.Trace(err) return false } u.SetValue(uint64(^n)) @@ -462,14 +463,14 @@ func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool { u.SetValue(mysql.ZeroDecimal.Sub(x.ToNumber())) case string: f, err := types.StrToFloat(x) - e.err = err + e.err = errors.Trace(err) u.SetValue(-f) case mysql.Decimal: f, _ := x.Float64() u.SetValue(mysql.NewDecimalFromFloat(-f)) case []byte: f, err := types.StrToFloat(string(x)) - e.err = err + e.err = errors.Trace(err) u.SetValue(-f) case mysql.Hex: u.SetValue(-x.ToNumber()) diff --git a/optimizer/evaluator/evaluator_binop.go b/optimizer/evaluator/evaluator_binop.go index 9035a9611d..fb3e10ea75 100644 --- a/optimizer/evaluator/evaluator_binop.go +++ b/optimizer/evaluator/evaluator_binop.go @@ -96,7 +96,7 @@ func (e *Evaluator) handleOrOr(o *ast.BinaryOperationExpr) bool { if !types.IsNil(leftVal) { x, err := types.ToBool(leftVal) if err != nil { - e.err = err + e.err = errors.Trace(err) return false } else if x == 1 { // true || any other types is true. @@ -137,7 +137,7 @@ func (e *Evaluator) handleXor(o *ast.BinaryOperationExpr) bool { y, err := types.ToBool(righVal) if err != nil { - e.err = err + e.err = errors.Trace(err) return false } if x == y { @@ -167,13 +167,13 @@ func (e *Evaluator) handleComparisonOp(o *ast.BinaryOperationExpr) bool { n, err := types.Compare(a, b) if err != nil { - e.err = err + e.err = errors.Trace(err) return false } r, err := getCompResult(o.Op, n) if err != nil { - e.err = err + e.err = errors.Trace(err) return false } if r { @@ -215,13 +215,13 @@ func (e *Evaluator) handleBitOp(o *ast.BinaryOperationExpr) bool { x, err := types.ToInt64(a) if err != nil { - e.err = err + e.err = errors.Trace(err) return false } y, err := types.ToInt64(b) if err != nil { - e.err = err + e.err = errors.Trace(err) return false } @@ -256,8 +256,8 @@ func (e *Evaluator) handleArithmeticOp(o *ast.BinaryOperationExpr) bool { e.err = errors.Trace(err) return false } - a, b = types.Coerce(a, b) + a, b = types.Coerce(a, b) if a == nil || b == nil { o.SetValue(nil) return true @@ -385,7 +385,7 @@ func computeDiv(a, b interface{}) (interface{}, error) { case float64: y, err := types.ToFloat64(b) if err != nil { - return nil, err + return nil, errors.Trace(err) } if y == 0 { @@ -399,12 +399,12 @@ func computeDiv(a, b interface{}) (interface{}, error) { // we will use 4 here xa, err := types.ToDecimal(a) if err != nil { - return nil, err + return nil, errors.Trace(err) } xb, err := types.ToDecimal(b) if err != nil { - return nil, err + return nil, errors.Trace(err) } if f, _ := xb.Float64(); f == 0 { // division by zero return null @@ -505,12 +505,12 @@ func computeIntDiv(a, b interface{}) (interface{}, error) { // if any is none integer, use decimal to calculate x, err := types.ToDecimal(a) if err != nil { - return nil, err + return nil, errors.Trace(err) } y, err := types.ToDecimal(b) if err != nil { - return nil, err + return nil, errors.Trace(err) } if f, _ := y.Float64(); f == 0 { @@ -526,9 +526,9 @@ func coerceArithmetic(a interface{}) (interface{}, error) { // MySQL will convert string to float for arithmetic operation f, err := types.StrToFloat(x) if err != nil { - return nil, err + return nil, errors.Trace(err) } - return f, err + return f, errors.Trace(err) case mysql.Time: // if time has no precision, return int64 v := x.ToNumber() @@ -547,9 +547,9 @@ func coerceArithmetic(a interface{}) (interface{}, error) { // []byte is the same as string, converted to float for arithmetic operator. f, err := types.StrToFloat(string(x)) if err != nil { - return nil, err + return nil, errors.Trace(err) } - return f, err + return f, errors.Trace(err) case mysql.Hex: return x.ToNumber(), nil case mysql.Bit: diff --git a/optimizer/evaluator/evaluator_like.go b/optimizer/evaluator/evaluator_like.go index b64768e73e..afb49b2376 100644 --- a/optimizer/evaluator/evaluator_like.go +++ b/optimizer/evaluator/evaluator_like.go @@ -14,10 +14,11 @@ package evaluator import ( + "regexp" + "github.com/juju/errors" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/util/types" - "regexp" ) const ( @@ -199,7 +200,7 @@ func (e *Evaluator) patternRegexp(p *ast.PatternRegexpExpr) bool { } if re, err = regexp.Compile(spattern); err != nil { - e.err = err + e.err = errors.Trace(err) return false } diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index 8f858757d5..d76fd4a01e 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -25,7 +25,7 @@ import ( "github.com/pingcap/tidb/terror" ) -// Optimize do optimization and create a Plan. +// Optimize does optimization and creates a Plan. // The node must be prepared first. func Optimize(ctx context.Context, node ast.Node) (plan.Plan, error) { // We have to inter type again because after parameter is set, the expression type may change. @@ -39,14 +39,21 @@ func Optimize(ctx context.Context, node ast.Node) (plan.Plan, error) { if err != nil { return nil, errors.Trace(err) } - bestCost := plan.EstimateCost(p) - bestPlan := p - alts, err := plan.Alternatives(p) if err != nil { return nil, errors.Trace(err) } + err = plan.Refine(p) + if err != nil { + return nil, errors.Trace(err) + } + bestCost := plan.EstimateCost(p) + bestPlan := p for _, alt := range alts { + err = plan.Refine(alt) + if err != nil { + return nil, errors.Trace(err) + } cost := plan.EstimateCost(alt) if cost < bestCost { bestCost = cost @@ -57,8 +64,8 @@ func Optimize(ctx context.Context, node ast.Node) (plan.Plan, error) { } // Prepare prepares a raw statement parsed from parser. -// The statement must be prepared before it can be passed to optimize function -// We pass InfoSchema instead of get from Context in case it is changed after resolving name. +// The statement must be prepared before it can be passed to optimize function. +// We pass InfoSchema instead of getting from Context in case it is changed after resolving name. func Prepare(is infoschema.InfoSchema, ctx context.Context, node ast.Node) error { if err := Validate(node, true); err != nil { return errors.Trace(err) diff --git a/optimizer/plan/alternatives.go b/optimizer/plan/alternatives.go index 94d6692234..caf6ed2419 100644 --- a/optimizer/plan/alternatives.go +++ b/optimizer/plan/alternatives.go @@ -16,7 +16,7 @@ package plan import "github.com/juju/errors" // Alternatives returns multiple alternative plans that -// can be picked base on their cost. +// can be picked based on their cost. func Alternatives(p Plan) ([]Plan, error) { var plans []Plan switch x := p.(type) { @@ -37,12 +37,6 @@ func Alternatives(p Plan) ([]Plan, error) { default: return nil, ErrUnsupportedType.Gen("Unknown plan %T", p) } - for _, val := range plans { - err := refine(val) - if err != nil { - return nil, errors.Trace(err) - } - } return plans, nil } @@ -54,19 +48,19 @@ func tableScanAlternatives(p *TableScan) []Plan { LowVal: []interface{}{nil}, HighVal: []interface{}{MaxVal}, } - ip := &IndexScan{ + is := &IndexScan{ Index: v, Table: p.Table, Ranges: []*IndexRange{fullRange}, } - ip.SetFields(p.Fields()) - alts = append(alts, ip) + is.SetFields(p.Fields()) + alts = append(alts, is) } return alts } // planWithSrcAlternatives shallow copies the WithSrcPlan, -// and set its src to src alternatives. +// and sets its src to src alternatives. func planWithSrcAlternatives(p WithSrcPlan) ([]Plan, error) { srcs, err := Alternatives(p.Src()) if err != nil { diff --git a/optimizer/plan/cost.go b/optimizer/plan/cost.go index 952e6c05c8..62d8d86105 100644 --- a/optimizer/plan/cost.go +++ b/optimizer/plan/cost.go @@ -43,13 +43,7 @@ func (c *costEstimator) Leave(p Plan) (Plan, bool) { case *IndexScan: c.indexScan(v) case *TableScan: - v.startupCost = 0 - if v.limit == 0 { - v.rowCount = FullRangeCount - } else { - v.rowCount = math.Min(FullRangeCount, v.limit) - } - v.totalCost = v.rowCount * RowCost + c.tableScan(v) case *SelectFields: if v.Src() != nil { v.startupCost = v.Src().StartupCost() @@ -71,7 +65,7 @@ func (c *costEstimator) Leave(p Plan) (Plan, bool) { v.rowCount = v.Src().RowCount() v.totalCost = v.Src().TotalCost() } else { - // Sort plan must retrieves all the rows before returns the first row. + // Sort plan must retrieve all the rows before returns the first row. v.startupCost = v.Src().TotalCost() + v.Src().RowCount()*SortCost if v.limit == 0 { v.rowCount = v.Src().RowCount() @@ -87,6 +81,41 @@ func (c *costEstimator) Leave(p Plan) (Plan, bool) { } return p, true } + +func (c *costEstimator) tableScan(v *TableScan) { + var rowCount float64 + if len(v.Ranges) == 1 && v.Ranges[0].LowVal == math.MinInt64 && v.Ranges[0].HighVal == math.MaxInt64 { + // full range use default row count. + rowCount = FullRangeCount + } else { + for _, v := range v.Ranges { + // for condition like 'a = 0'. + if v.LowVal == v.HighVal { + rowCount++ + continue + } + // For condition like 'a < 0'. + if v.LowVal == math.MinInt64 { + rowCount += HalfRangeCount + } + // For condition like 'a > 0'. + if v.HighVal == math.MaxInt64 { + rowCount += HalfRangeCount + } + // For condition like 'a > 0 and a < 1'. + rowCount += MiddleRangeCount + } + } + v.startupCost = 0 + if v.limit == 0 { + // limit is zero means no limit. + v.rowCount = rowCount + } else { + v.rowCount = math.Min(rowCount, v.limit) + } + v.totalCost = v.rowCount * RowCost +} + func (c *costEstimator) indexScan(v *IndexScan) { var rowCount float64 if len(v.Ranges) == 1 && v.Ranges[0].LowVal[0] == nil && v.Ranges[0].HighVal[0] == MaxVal { @@ -108,7 +137,7 @@ func (c *costEstimator) indexScan(v *IndexScan) { rowCount += HalfRangeCount } // For condition like 'a > 0 and a < 1'. - rowCount += 100 + rowCount += MiddleRangeCount } // If the index has too many ranges, the row count may exceed the default row count. // Make sure the cost is lower than full range. diff --git a/optimizer/plan/explainer.go b/optimizer/plan/explainer.go index 01d2bd7013..b3bffad4b5 100644 --- a/optimizer/plan/explainer.go +++ b/optimizer/plan/explainer.go @@ -15,6 +15,7 @@ package plan import ( "fmt" + "math" "strings" ) @@ -38,7 +39,16 @@ func (e *explainer) Leave(in Plan) (Plan, bool) { var str string switch x := in.(type) { case *TableScan: - str = fmt.Sprintf("Table(%s)", x.Table.Name.L) + if len(x.Ranges) > 0 { + ran := x.Ranges[0] + if ran.LowVal != math.MinInt64 || ran.HighVal != math.MaxInt64 { + str = fmt.Sprintf("Range(%s)", x.Table.Name.L) + } else { + str = fmt.Sprintf("Table(%s)", x.Table.Name.L) + } + } else { + str = fmt.Sprintf("Table(%s)", x.Table.Name.L) + } case *IndexScan: str = fmt.Sprintf("Index(%s.%s)", x.Table.Name.L, x.Index.Name.L) case *ShowDDL: diff --git a/optimizer/plan/plan_test.go b/optimizer/plan/plan_test.go index b7022a8c4d..315e0e8395 100644 --- a/optimizer/plan/plan_test.go +++ b/optimizer/plan/plan_test.go @@ -21,6 +21,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser" ) @@ -253,7 +254,7 @@ func (s *testPlanSuite) TestBuilder(c *C) { c.Assert(err, IsNil) explainStr, err := Explain(p) c.Assert(err, IsNil) - c.Assert(ca.planStr, Equals, explainStr) + c.Assert(explainStr, Equals, ca.planStr, Commentf("for expr %s", ca.sqlStr)) } } @@ -268,7 +269,7 @@ func (s *testPlanSuite) TestBestPlan(c *C) { }, { sql: "select * from t order by a", - best: "Index(t.a)->Fields", + best: "Table(t)->Fields", }, { sql: "select * from t where b = 1 order by a", @@ -291,8 +292,8 @@ func (s *testPlanSuite) TestBestPlan(c *C) { best: "Index(t.c_d)->Filter->Fields", }, { - sql: "select * from t where a like 'abc%'", - best: "Index(t.a)->Filter->Fields", + sql: "select * from t where b like 'abc%'", + best: "Index(t.b)->Filter->Fields", }, { sql: "select * from t where d", @@ -300,22 +301,29 @@ func (s *testPlanSuite) TestBestPlan(c *C) { }, { sql: "select * from t where a is null", - best: "Index(t.a)->Filter->Fields", + best: "Range(t)->Filter->Fields", }, } for _, ca := range cases { + comment := Commentf("for %s", ca.sql) s, err := parser.ParseOneStmt(ca.sql, "", "") - c.Assert(err, IsNil, Commentf("for expr %s", ca.sql)) + c.Assert(err, IsNil, comment) stmt := s.(*ast.SelectStmt) ast.SetFlag(stmt) mockResolve(stmt) + p, err := BuildPlan(stmt) c.Assert(err, IsNil) + alts, err := Alternatives(p) + c.Assert(err, IsNil) + + err = Refine(p) + c.Assert(err, IsNil) bestCost := EstimateCost(p) bestPlan := p - alts, err := Alternatives(p) - c.Assert(err, IsNil) + for _, alt := range alts { + c.Assert(Refine(alt), IsNil) cost := EstimateCost(alt) if cost < bestCost { bestCost = cost @@ -330,14 +338,6 @@ func (s *testPlanSuite) TestBestPlan(c *C) { func mockResolve(node ast.Node) { indices := []*model.IndexInfo{ - { - Name: model.NewCIStr("a"), - Columns: []*model.IndexColumn{ - { - Name: model.NewCIStr("a"), - }, - }, - }, { Name: model.NewCIStr("b"), Columns: []*model.IndexColumn{ @@ -358,9 +358,15 @@ func mockResolve(node ast.Node) { }, }, } + pkColumn := &model.ColumnInfo{ + Name: model.NewCIStr("a"), + } + pkColumn.Flag = mysql.PriKeyFlag table := &model.TableInfo{ - Indices: indices, - Name: model.NewCIStr("t"), + Columns: []*model.ColumnInfo{pkColumn}, + Indices: indices, + Name: model.NewCIStr("t"), + PKIsHandle: true, } resolver := mockResolver{table: table} node.Accept(&resolver) @@ -383,6 +389,9 @@ func (b *mockResolver) Leave(in ast.Node) (ast.Node, bool) { }, Table: b.table, } + if x.Name.Name.L == "a" { + x.Refer.Column = b.table.Columns[0] + } case *ast.TableName: x.TableInfo = b.table } diff --git a/optimizer/plan/planbuilder.go b/optimizer/plan/planbuilder.go index 18ec8e8970..933092b18f 100644 --- a/optimizer/plan/planbuilder.go +++ b/optimizer/plan/planbuilder.go @@ -14,6 +14,8 @@ package plan import ( + "math" + "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/model" @@ -39,11 +41,7 @@ const ( func BuildPlan(node ast.Node) (Plan, error) { var builder planBuilder p := builder.build(node) - if builder.err != nil { - return nil, builder.err - } - err := refine(p) - return p, err + return p, builder.err } // planBuilder builds Plan from an ast.Node. @@ -132,7 +130,8 @@ func (b *planBuilder) buildJoin(from *ast.Join) Plan { return nil } p := &TableScan{ - Table: tn.TableInfo, + Table: tn.TableInfo, + Ranges: []TableRange{{math.MinInt64, math.MaxInt64}}, } p.SetFields(tn.GetResultFields()) return p diff --git a/optimizer/plan/plans.go b/optimizer/plan/plans.go index 2a2f27acaf..df8f1235b3 100644 --- a/optimizer/plan/plans.go +++ b/optimizer/plan/plans.go @@ -18,12 +18,19 @@ import ( "github.com/pingcap/tidb/model" ) +// TableRange represents a range of row handle. +type TableRange struct { + LowVal int64 + HighVal int64 +} + // TableScan represents a table scan plan. type TableScan struct { basePlan - Table *model.TableInfo - Desc bool + Table *model.TableInfo + Desc bool + Ranges []TableRange } // Accept implements Plan Accept interface. diff --git a/optimizer/plan/range.go b/optimizer/plan/range.go index 77077b6c92..b24bba42fd 100644 --- a/optimizer/plan/range.go +++ b/optimizer/plan/range.go @@ -15,8 +15,10 @@ package plan import ( "fmt" + "math" "sort" + "github.com/juju/errors" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/util/types" @@ -425,3 +427,50 @@ func (r *rangeBuilder) appendIndexRange(origin *IndexRange, rangePoints []rangeP } return newRanges } + +func (r *rangeBuilder) buildTableRanges(rangePoints []rangePoint) []TableRange { + tableRanges := make([]TableRange, 0, len(rangePoints)/2) + for i := 0; i < len(rangePoints); i += 2 { + startPoint := rangePoints[i] + if startPoint.value == nil || startPoint.value == MinNotNullVal { + startPoint.value = math.MinInt64 + } + startInt, err := types.ToInt64(startPoint.value) + if err != nil { + r.err = errors.Trace(err) + return tableRanges + } + cmp, err := types.Compare(startInt, startPoint.value) + if err != nil { + r.err = errors.Trace(err) + return tableRanges + } + if cmp < 0 || (cmp == 0 && startPoint.excl) { + startInt++ + } + endPoint := rangePoints[i+1] + if endPoint.value == nil { + endPoint.value = math.MinInt64 + } else if endPoint.value == MaxVal { + endPoint.value = math.MaxInt64 + } + endInt, err := types.ToInt64(endPoint.value) + if err != nil { + r.err = errors.Trace(err) + return tableRanges + } + cmp, err = types.Compare(endInt, endPoint.value) + if err != nil { + r.err = errors.Trace(err) + return tableRanges + } + if cmp > 0 || (cmp == 0 && endPoint.excl) { + endInt-- + } + if startInt > endInt { + continue + } + tableRanges = append(tableRanges, TableRange{LowVal: startInt, HighVal: endInt}) + } + return tableRanges +} diff --git a/optimizer/plan/refiner.go b/optimizer/plan/refiner.go index 6b8cc7b149..9775fa9d86 100644 --- a/optimizer/plan/refiner.go +++ b/optimizer/plan/refiner.go @@ -16,21 +16,23 @@ package plan import ( "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" ) -func refine(p Plan) error { +// Refine tries to build index range, bypass sort, set limit for source plan. +// It prepares the plan for cost estimation. +func Refine(p Plan) error { r := refiner{} p.Accept(&r) return r.err } -// refiner tries to build index range, bypass sort, set limit for source plan. -// It prepares the plan for cost estimation. type refiner struct { conditions []ast.ExprNode - // store index scan plan for sort to use. + // store scan plan for sort to use. indexScan *IndexScan + tableScan *TableScan err error } @@ -40,12 +42,16 @@ func (r *refiner) Enter(in Plan) (Plan, bool) { r.conditions = x.Conditions case *IndexScan: r.indexScan = x + case *TableScan: + r.tableScan = x } return in, false } func (r *refiner) Leave(in Plan) (Plan, bool) { switch x := in.(type) { + case *TableScan: + r.buildTableRange(x) case *IndexScan: r.buildIndexRange(x) case *Sort: @@ -85,6 +91,26 @@ func (r *refiner) sortBypass(p *Sort) { return } p.Bypass = true + } else if r.tableScan != nil { + if len(p.ByItems) != 1 { + return + } + byItem := p.ByItems[0] + if byItem.Desc { + // TODO: support desc when table reverse iterator is supported. + return + } + cn, ok := byItem.Expr.(*ast.ColumnNameExpr) + if !ok { + return + } + if !mysql.HasPriKeyFlag(cn.Refer.Column.Flag) { + return + } + if !cn.Refer.Table.PKIsHandle { + return + } + p.Bypass = true } } @@ -122,12 +148,35 @@ func (r *refiner) buildIndexRange(p *IndexScan) { return } +func (r *refiner) buildTableRange(p *TableScan) { + var pkHandleColumn *model.ColumnInfo + for _, colInfo := range p.Table.Columns { + if mysql.HasPriKeyFlag(colInfo.Flag) && p.Table.PKIsHandle { + pkHandleColumn = colInfo + } + } + if pkHandleColumn == nil { + return + } + rb := rangeBuilder{} + rangePoints := fullRange + checker := conditionChecker{pkName: pkHandleColumn.Name, tableName: p.Table.Name} + for _, cond := range r.conditions { + if checker.check(cond) { + rangePoints = rb.intersection(rangePoints, rb.build(cond)) + } + } + p.Ranges = rb.buildTableRanges(rangePoints) + r.err = rb.err +} + // conditionChecker checks if this condition can be pushed to index plan. type conditionChecker struct { tableName model.CIStr idx *model.IndexInfo // the offset of the indexed column to be checked. columnOffset int + pkName model.CIStr } func (c *conditionChecker) check(condition ast.ExprNode) bool { @@ -204,8 +253,9 @@ func (c *conditionChecker) checkColumnExpr(expr ast.ExprNode) bool { if cn.Refer.Table.Name.L != c.tableName.L { return false } - if cn.Refer.Column.Name.L != c.idx.Columns[c.columnOffset].Name.L { - return false + if c.pkName.L != "" { + return c.pkName.L == cn.Refer.Column.Name.L } - return true + + return cn.Refer.Column.Name.L == c.idx.Columns[c.columnOffset].Name.L } diff --git a/optimizer/resolver.go b/optimizer/resolver.go index 2e4aab4548..90d4d53b37 100644 --- a/optimizer/resolver.go +++ b/optimizer/resolver.go @@ -23,12 +23,12 @@ import ( ) // ResolveName resolves table name and column name. -// It generates ResultFields for ResetSetNode and resolve ColumnNameExpr to a ResultField. +// It generates ResultFields for ResultSetNode and resolves ColumnNameExpr to a ResultField. func ResolveName(node ast.Node, info infoschema.InfoSchema, ctx context.Context) error { defaultSchema := db.GetCurrentSchema(ctx) resolver := nameResolver{Info: info, Ctx: ctx, DefaultSchema: model.NewCIStr(defaultSchema)} node.Accept(&resolver) - return resolver.Err + return errors.Trace(resolver.Err) } // nameResolver is the visitor to resolve table name and column name. @@ -37,7 +37,7 @@ func ResolveName(node ast.Node, info infoschema.InfoSchema, ctx context.Context) // available for following elements. // // During visiting, information are collected and stored in resolverContext. -// When we enter a subquery, a new resolverContext is pushed to the contextStack, so sub query +// When we enter a subquery, a new resolverContext is pushed to the contextStack, so subquery // information can overwrite outer query information. When we look up for a column reference, // we look up from top to bottom in the contextStack. type nameResolver struct { @@ -139,7 +139,7 @@ func (nr *nameResolver) Enter(inNode ast.Node) (outNode ast.Node, skipChildren b case *ast.ByItem: if _, ok := v.Expr.(*ast.ColumnNameExpr); !ok { // If ByItem is not a single column name expression, - // the resolving rule is different for order by clause. + // the resolving rule is different from order by clause. nr.currentContext().inByItemExpression = true } case *ast.InsertStmt: @@ -194,8 +194,7 @@ func (nr *nameResolver) Leave(inNode ast.Node) (node ast.Node, ok bool) { return inNode, nr.Err == nil } -// handleTableName looks up and set the schema information for table name -// and set result fields for table name. +// handleTableName looks up and sets the schema information and result fields for table name. func (nr *nameResolver) handleTableName(tn *ast.TableName) { if tn.Schema.L == "" { tn.Schema = nr.DefaultSchema @@ -282,7 +281,7 @@ func (nr *nameResolver) handleColumnName(cn *ast.ColumnNameExpr) { } // resolveColumnNameInContext looks up and sets ResultField for a column with the ctx. -func (nr *nameResolver) resolveColumnNameInContext(ctx *resolverContext, cn *ast.ColumnNameExpr) (done bool) { +func (nr *nameResolver) resolveColumnNameInContext(ctx *resolverContext, cn *ast.ColumnNameExpr) bool { if ctx.inTableRefs { // In TableRefsClause, column reference only in join on condition which is handled before. return false @@ -439,7 +438,7 @@ func (nr *nameResolver) resolveColumnInResultFields(cn *ast.ColumnNameExpr, rfs return false } -// handleFieldList expands wild card field and set fieldList in current context. +// handleFieldList expands wild card field and sets fieldList in current context. func (nr *nameResolver) handleFieldList(fieldList *ast.FieldList) { var resultFields []*ast.ResultField for _, v := range fieldList.Fields { @@ -460,7 +459,7 @@ func (nr *nameResolver) createResultFields(field *ast.SelectField) (rfs []*ast.R ctx := nr.currentContext() if field.WildCard != nil { if len(ctx.tables) == 0 { - nr.Err = errors.Errorf("No table used.") + nr.Err = errors.New("No table used.") return } if field.WildCard.Table.L == "" { diff --git a/optimizer/validator.go b/optimizer/validator.go index d7f2b3563d..e2f420e457 100644 --- a/optimizer/validator.go +++ b/optimizer/validator.go @@ -63,7 +63,7 @@ func (v *validator) Leave(in ast.Node) (out ast.Node, ok bool) { } // checkAllOneColumn checks that all expressions have one column. -// Expression may has more than one column when it is a rowExpr or +// Expression may have more than one column when it is a rowExpr or // a Subquery with more than one result fields. func (v *validator) checkAllOneColumn(exprs ...ast.ExprNode) { for _, expr := range exprs { @@ -113,7 +113,7 @@ func (v *validator) checkSameColumns(exprs ...ast.ExprNode) { } } -// checkFieldList checks there is only one '*" and each field has only one column, +// checkFieldList checks if there is only one '*' and each field has only one column. func (v *validator) checkFieldList(x *ast.FieldList) { var hasWildCard bool for _, val := range x.Fields { diff --git a/plan/plans/from.go b/plan/plans/from.go index af1d7d5923..20ee2662d9 100644 --- a/plan/plans/from.go +++ b/plan/plans/from.go @@ -18,6 +18,8 @@ package plans import ( + "math" + "github.com/juju/errors" "github.com/pingcap/tidb/column" "github.com/pingcap/tidb/context" @@ -103,17 +105,28 @@ type TableDefaultPlan struct { T table.Table Fields []*field.ResultField iter kv.Iterator + + // for range scan. + rangeScan bool + spans []*indexSpan + seekKey kv.Key + cursor int + skipLowCmp bool } // Explain implements the plan.Plan Explain interface. func (r *TableDefaultPlan) Explain(w format.Formatter) { - w.Format("┌Iterate all rows of table %q\n└Output field names %v\n", r.T.TableName(), field.RFQNames(r.Fields)) + fmtStr := "┌Iterate all rows of table %q\n└Output field names %v\n" + if r.rangeScan { + fmtStr = "┌Range scan rows of table %q\n└Output field names %v\n" + } + w.Format(fmtStr, r.T.TableName(), field.RFQNames(r.Fields)) } func (r *TableDefaultPlan) filterBinOp(ctx context.Context, x *expression.BinaryOperation) (plan.Plan, bool, error) { ok, name, rval, err := x.IsIdentCompareVal() if err != nil { - return r, false, err + return r, false, errors.Trace(err) } if !ok { return r, false, nil @@ -134,23 +147,37 @@ func (r *TableDefaultPlan) filterBinOp(ctx context.Context, x *expression.Binary if c == nil { return nil, false, errors.Errorf("No such column: %s", cn) } + var seekVal interface{} + if seekVal, err = types.Convert(rval, &c.FieldType); err != nil { + return nil, false, errors.Trace(err) + } + spans := toSpans(x.Op, rval, seekVal) + if c.IsPKHandleColumn(r.T.Meta()) { + if r.rangeScan { + spans = filterSpans(r.spans, spans) + } + return &TableDefaultPlan{ + T: r.T, + Fields: r.Fields, + rangeScan: true, + spans: spans, + }, true, nil + } else if r.rangeScan { + // Already filtered on PK handle column, should not switch to index plan. + return r, false, nil + } ix := t.FindIndexByColName(cn) if ix == nil { // Column cn has no index. return r, false, nil } - - var seekVal interface{} - if seekVal, err = types.Convert(rval, &c.FieldType); err != nil { - return nil, false, err - } return &indexPlan{ src: t, col: c, unique: ix.Unique, idxName: ix.Name.O, idx: ix.X, - spans: toSpans(x.Op, rval, seekVal), + spans: spans, }, true, nil } @@ -160,6 +187,26 @@ func (r *TableDefaultPlan) filterIdent(ctx context.Context, x *expression.Ident, if x.L != v.Name.L { continue } + var spans []*indexSpan + if trueValue { + spans = toSpans(opcode.NE, 0, 0) + } else { + spans = toSpans(opcode.EQ, 0, 0) + } + + if v.IsPKHandleColumn(t.Meta()) { + if r.rangeScan { + spans = filterSpans(r.spans, spans) + } + return &TableDefaultPlan{ + T: r.T, + Fields: r.Fields, + rangeScan: true, + spans: spans, + }, true, nil + } else if r.rangeScan { + return r, false, nil + } xi := v.Offset if xi >= len(t.Indices()) { @@ -170,12 +217,6 @@ func (r *TableDefaultPlan) filterIdent(ctx context.Context, x *expression.Ident, if ix == nil { // Column cn has no index. return r, false, nil } - var spans []*indexSpan - if trueValue { - spans = toSpans(opcode.NE, 0, 0) - } else { - spans = toSpans(opcode.EQ, 0, 0) - } return &indexPlan{ src: t, col: v, @@ -202,17 +243,33 @@ func (r *TableDefaultPlan) filterIsNull(ctx context.Context, x *expression.IsNul cn := cns[0] t := r.T - ix := t.FindIndexByColName(cn) - if ix == nil { // Column cn has no index. + col := column.FindCol(t.Cols(), cn) + if col == nil { return r, false, nil } - col := column.FindCol(t.Cols(), cn) + if col.IsPKHandleColumn(t.Meta()) { + if x.Not { + // PK handle column never be null. + return r, false, nil + } + return &NullPlan{ + Fields: r.Fields, + }, true, nil + } else if r.rangeScan { + return r, false, nil + } + var spans []*indexSpan if x.Not { spans = toSpans(opcode.GE, minNotNullVal, nil) } else { spans = toSpans(opcode.EQ, nil, nil) } + + ix := t.FindIndexByColName(cn) + if ix == nil { // Column cn has no index. + return r, false, nil + } return &indexPlan{ src: t, col: col, @@ -270,6 +327,9 @@ func (r *TableDefaultPlan) GetFields() []*field.ResultField { // Next implements plan.Plan Next interface. func (r *TableDefaultPlan) Next(ctx context.Context) (row *plan.Row, err error) { + if r.rangeScan { + return r.rangeNext(ctx) + } if r.iter == nil { var txn kv.Transaction txn, err = ctx.GetTxn(false) @@ -320,11 +380,100 @@ func (r *TableDefaultPlan) Next(ctx context.Context) (row *plan.Row, err error) return } +func (r *TableDefaultPlan) rangeNext(ctx context.Context) (*plan.Row, error) { + for { + if r.cursor == len(r.spans) { + return nil, nil + } + span := r.spans[r.cursor] + if r.seekKey == nil { + seekVal := span.seekVal + var err error + r.seekKey, err = r.toSeekKey(seekVal) + if err != nil { + return nil, errors.Trace(err) + } + } + txn, err := ctx.GetTxn(false) + if err != nil { + return nil, errors.Trace(err) + } + if r.iter != nil { + r.iter.Close() + } + r.iter, err = txn.Seek(r.seekKey) + if err != nil { + return nil, types.EOFAsNil(err) + } + + if !r.iter.Valid() || !r.iter.Key().HasPrefix(r.T.RecordPrefix()) { + r.seekKey = nil + r.cursor++ + r.skipLowCmp = false + continue + } + rowKey := r.iter.Key() + handle, err := tables.DecodeRecordKeyHandle(rowKey) + if err != nil { + return nil, errors.Trace(err) + } + r.seekKey, err = r.toSeekKey(handle + 1) + if err != nil { + return nil, errors.Trace(err) + } + if !r.skipLowCmp { + cmp := indexCompare(handle, span.lowVal) + if cmp < 0 || (cmp == 0 && span.lowExclude) { + continue + } + r.skipLowCmp = true + } + cmp := indexCompare(handle, span.highVal) + if cmp > 0 || (cmp == 0 && span.highExclude) { + // This span has finished iteration. + // Move to the next span. + r.seekKey = nil + r.cursor++ + r.skipLowCmp = false + continue + } + row := &plan.Row{} + row.Data, err = r.T.Row(ctx, handle) + if err != nil { + return nil, errors.Trace(err) + } + // Put rowKey to the tail of record row + rke := &plan.RowKeyEntry{ + Tbl: r.T, + Key: string(rowKey), + } + row.RowKeys = append(row.RowKeys, rke) + return row, nil + } +} + +func (r *TableDefaultPlan) toSeekKey(seekVal interface{}) (kv.Key, error) { + var handle int64 + var err error + if seekVal == nil { + handle = math.MinInt64 + } else { + handle, err = types.ToInt64(seekVal) + if err != nil { + return nil, errors.Trace(err) + } + } + return tables.EncodeRecordKey(r.T.TableID(), handle, 0), nil +} + // Close implements plan.Plan Close interface. func (r *TableDefaultPlan) Close() error { if r.iter != nil { r.iter.Close() r.iter = nil } + r.seekKey = nil + r.cursor = 0 + r.skipLowCmp = false return nil } diff --git a/plan/plans/index.go b/plan/plans/index.go index f35e7e5eb2..c6f8f84250 100644 --- a/plan/plans/index.go +++ b/plan/plans/index.go @@ -201,7 +201,7 @@ func (r *indexPlan) Filter(ctx context.Context, expr expression.Expression) (pla case *expression.BinaryOperation: ok, name, val, err := x.IsIdentCompareVal() if err != nil { - return nil, false, err + return nil, false, errors.Trace(err) } if !ok { break diff --git a/plan/plans/info.go b/plan/plans/info.go index 925d9d2c10..eb19c7fd59 100644 --- a/plan/plans/info.go +++ b/plan/plans/info.go @@ -435,86 +435,119 @@ func (isp *InfoSchemaPlan) fetchTables(schemas []*model.DBInfo) { func (isp *InfoSchemaPlan) fetchColumns(schemas []*model.DBInfo) { for _, schema := range schemas { for _, table := range schema.Tables { - for i, col := range table.Columns { - colLen := col.Flen - if colLen == types.UnspecifiedLength { - colLen = mysql.GetDefaultFieldLength(col.Tp) - } - decimal := col.Decimal - if decimal == types.UnspecifiedLength { - decimal = 0 - } - columnType := col.FieldType.CompactStr() - columnDesc := column.NewColDesc(&column.Col{ColumnInfo: *col}) - var columnDefault interface{} - if columnDesc.DefaultValue != nil { - columnDefault = fmt.Sprintf("%v", columnDesc.DefaultValue) - } - record := []interface{}{ - catalogVal, // TABLE_CATALOG - schema.Name.O, // TABLE_SCHEMA - table.Name.O, // TABLE_NAME - col.Name.O, // COLUMN_NAME - i + 1, // ORIGINAL_POSITION - columnDefault, // COLUMN_DEFAULT - columnDesc.Null, // IS_NULLABLE - types.TypeToStr(col.Tp, col.Charset), // DATA_TYPE - colLen, // CHARACTER_MAXIMUM_LENGTH - colLen, // CHARACTOR_OCTET_LENGTH - decimal, // NUMERIC_PRECISION - 0, // NUMERIC_SCALE - 0, // DATETIME_PRECISION - col.Charset, // CHARACTER_SET_NAME - col.Collate, // COLLATION_NAME - columnType, // COLUMN_TYPE - columnDesc.Key, // COLUMN_KEY - columnDesc.Extra, // EXTRA - "select,insert,update,references", // PRIVILEGES - "", // COLUMN_COMMENT - } - isp.rows = append(isp.rows, &plan.Row{Data: record}) - } + isp.fetchColumnsInTable(schema, table) } } } +func (isp *InfoSchemaPlan) fetchColumnsInTable(schema *model.DBInfo, table *model.TableInfo) { + for i, col := range table.Columns { + colLen := col.Flen + if colLen == types.UnspecifiedLength { + colLen = mysql.GetDefaultFieldLength(col.Tp) + } + decimal := col.Decimal + if decimal == types.UnspecifiedLength { + decimal = 0 + } + columnType := col.FieldType.CompactStr() + columnDesc := column.NewColDesc(&column.Col{ColumnInfo: *col}) + var columnDefault interface{} + if columnDesc.DefaultValue != nil { + columnDefault = fmt.Sprintf("%v", columnDesc.DefaultValue) + } + record := []interface{}{ + catalogVal, // TABLE_CATALOG + schema.Name.O, // TABLE_SCHEMA + table.Name.O, // TABLE_NAME + col.Name.O, // COLUMN_NAME + i + 1, // ORIGINAL_POSITION + columnDefault, // COLUMN_DEFAULT + columnDesc.Null, // IS_NULLABLE + types.TypeToStr(col.Tp, col.Charset), // DATA_TYPE + colLen, // CHARACTER_MAXIMUM_LENGTH + colLen, // CHARACTOR_OCTET_LENGTH + decimal, // NUMERIC_PRECISION + 0, // NUMERIC_SCALE + 0, // DATETIME_PRECISION + col.Charset, // CHARACTER_SET_NAME + col.Collate, // COLLATION_NAME + columnType, // COLUMN_TYPE + columnDesc.Key, // COLUMN_KEY + columnDesc.Extra, // EXTRA + "select,insert,update,references", // PRIVILEGES + "", // COLUMN_COMMENT + } + isp.rows = append(isp.rows, &plan.Row{Data: record}) + } +} + func (isp *InfoSchemaPlan) fetchStatistics(is infoschema.InfoSchema, schemas []*model.DBInfo) { for _, schema := range schemas { for _, table := range schema.Tables { - for _, index := range table.Indices { - nonUnique := "1" - if index.Unique { - nonUnique = "0" - } - for i, key := range index.Columns { - col, _ := is.ColumnByName(schema.Name, table.Name, key.Name) - nullable := "YES" - if mysql.HasNotNullFlag(col.Flag) { - nullable = "" - } - record := []interface{}{ - catalogVal, // TABLE_CATALOG - schema.Name.O, // TABLE_SCHEMA - table.Name.O, // TABLE_NAME - nonUnique, // NON_UNIQUE - schema.Name.O, // INDEX_SCHEMA - index.Name.O, // INDEX_NAME - i + 1, // SEQ_IN_INDEX - key.Name.O, // COLUMN_NAME - "A", // COLLATION - 0, // CARDINALITY - nil, // SUB_PART - nil, // PACKED - nullable, // NULLABLE - "BTREE", // INDEX_TYPE - "", // COMMENT - "", // INDEX_COMMENT - } - isp.rows = append(isp.rows, &plan.Row{Data: record}) + isp.fetchStatisticsInTable(is, schema, table) + } + } +} + +func (isp *InfoSchemaPlan) fetchStatisticsInTable(is infoschema.InfoSchema, schema *model.DBInfo, table *model.TableInfo) { + if table.PKIsHandle { + for _, col := range table.Columns { + if mysql.HasPriKeyFlag(col.Flag) { + record := []interface{}{ + catalogVal, // TABLE_CATALOG + schema.Name.O, // TABLE_SCHEMA + table.Name.O, // TABLE_NAME + "0", // NON_UNIQUE + schema.Name.O, // INDEX_SCHEMA + "PRIMARY", // INDEX_NAME + 1, // SEQ_IN_INDEX + col.Name.O, // COLUMN_NAME + "A", // COLLATION + 0, // CARDINALITY + nil, // SUB_PART + nil, // PACKED + "", // NULLABLE + "BTREE", // INDEX_TYPE + "", // COMMENT + "", // INDEX_COMMENT } + isp.rows = append(isp.rows, &plan.Row{Data: record}) } } } + for _, index := range table.Indices { + nonUnique := "1" + if index.Unique { + nonUnique = "0" + } + for i, key := range index.Columns { + col, _ := is.ColumnByName(schema.Name, table.Name, key.Name) + nullable := "YES" + if mysql.HasNotNullFlag(col.Flag) { + nullable = "" + } + record := []interface{}{ + catalogVal, // TABLE_CATALOG + schema.Name.O, // TABLE_SCHEMA + table.Name.O, // TABLE_NAME + nonUnique, // NON_UNIQUE + schema.Name.O, // INDEX_SCHEMA + index.Name.O, // INDEX_NAME + i + 1, // SEQ_IN_INDEX + key.Name.O, // COLUMN_NAME + "A", // COLLATION + 0, // CARDINALITY + nil, // SUB_PART + nil, // PACKED + nullable, // NULLABLE + "BTREE", // INDEX_TYPE + "", // COMMENT + "", // INDEX_COMMENT + } + isp.rows = append(isp.rows, &plan.Row{Data: record}) + } + } } func (isp *InfoSchemaPlan) fetchCharacterSets() { diff --git a/session_test.go b/session_test.go index fa890d09e6..a680c4437d 100644 --- a/session_test.go +++ b/session_test.go @@ -28,6 +28,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/autocommit" "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/terror" ) var _ = Suite(&testSessionSuite{}) @@ -611,6 +612,8 @@ func (s *testSessionSuite) TestSelect(c *C) { row, err = r.FirstRow() c.Assert(err, IsNil) match(c, row, 3) + + mustExecSQL(c, se, `select * from t1, t2 where t1.c1 is null`) } func (s *testSessionSuite) TestSubQuery(c *C) { @@ -1057,9 +1060,21 @@ func (s *testSessionSuite) TestIssue461(c *C) { mustExecSQL(c, se1, "insert into test(id, val) values(1, 1);") se2 := newSession(c, store, s.dbName) mustExecSQL(c, se2, "begin;") - mustExecSQL(c, se2, "insert into test(id, val) values(1, 1);") - mustExecSQL(c, se2, "commit;") - mustExecFailed(c, se1, "commit;") + mustExecSQL(c, se2, "insert into test(id, val) values(2, 2);") + se3 := newSession(c, store, s.dbName) + mustExecSQL(c, se3, "begin;") + mustExecSQL(c, se3, "insert into test(id, val) values(1, 2);") + mustExecSQL(c, se3, "commit;") + _, err := se1.Execute("commit") + c.Assert(err, NotNil) + // Check error type and error message + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + c.Assert(err.Error(), Equals, "[kv:3]Duplicate entry '1' for key 'PRIMARY'") + + _, err = se2.Execute("commit") + c.Assert(err, NotNil) + c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) + c.Assert(err.Error(), Equals, "[kv:3]Duplicate entry '2' for key 'val'") se := newSession(c, store, s.dbName) mustExecSQL(c, se, "drop table test;") diff --git a/stmt/stmts/delete_test.go b/stmt/stmts/delete_test.go index d49588a093..7c659579f2 100644 --- a/stmt/stmts/delete_test.go +++ b/stmt/stmts/delete_test.go @@ -114,7 +114,7 @@ func (s *testStmtSuite) TestDelete(c *C) { strs := s.queryStrings(s.testDB, `explain DELETE from test where id = 2;`, c) var useIndex bool for _, str := range strs { - if strings.Index(str, "index") > 0 { + if strings.Index(str, "Range") > 0 { useIndex = true } } diff --git a/stmt/stmts/select_test.go b/stmt/stmts/select_test.go index 87dce71217..b1a4d5cb28 100644 --- a/stmt/stmts/select_test.go +++ b/stmt/stmts/select_test.go @@ -46,7 +46,7 @@ func (s *testStmtSuite) TestSelectExplain(c *C) { strs := s.queryStrings(s.testDB, "explain select * from test where id = 1;", c) // Must use index - if strings.Index(strs[0], "index") < 0 { + if strings.Index(strs[0], "Range") < 0 { c.Fatalf("Should use index") } } diff --git a/stmt/stmts/update_test.go b/stmt/stmts/update_test.go index 404df96205..b07d84c914 100644 --- a/stmt/stmts/update_test.go +++ b/stmt/stmts/update_test.go @@ -58,7 +58,7 @@ func (s *testStmtSuite) TestUpdate(c *C) { strs := s.queryStrings(testDB, `explain `+updateStr, c) var useIndex bool for _, str := range strs { - if strings.Index(str, "index") > 0 { + if strings.Index(str, "Range") > 0 { useIndex = true } } diff --git a/store/hbase/hbase_test.go b/store/hbase/hbase_test.go index aafb535b15..e071c4c595 100644 --- a/store/hbase/hbase_test.go +++ b/store/hbase/hbase_test.go @@ -30,17 +30,17 @@ type testHBaseSuite struct { func (t *testHBaseSuite) TestParsePath(c *C) { tbl := []struct { - dsn string - ok bool - zks []string - oracle string - table string + dsn string + ok bool + zks string + tso string + table string }{ - {"hbase://z,k,zk/tbl", true, []string{"z", "k", "zk"}, "", "tbl"}, - {"hbase://z:80,k:80/tbl?tso=127.0.0.1:1234", true, []string{"z:80", "k:80"}, "127.0.0.1:1234", "tbl"}, - {"goleveldb://zk/tbl", false, nil, "", ""}, - {"hbase://zk/path/tbl", false, nil, "", ""}, - {"hbase:///zk/tbl", false, nil, "", ""}, + {"hbase://z,k,zk/tbl", true, "z,k,zk", tsoTypeLocal, "tbl"}, + {"hbase://z:80,k:80/tbl?tso=zk", true, "z:80,k:80", tsoTypeZK, "tbl"}, + {"goleveldb://zk/tbl", false, "", "", ""}, + {"hbase://zk/path/tbl", false, "", "", ""}, + {"hbase:///zk/tbl", false, "", "", ""}, } for _, t := range tbl { @@ -48,7 +48,7 @@ func (t *testHBaseSuite) TestParsePath(c *C) { if t.ok { c.Assert(err, IsNil, Commentf("dsn=%v", t.dsn)) c.Assert(zks, DeepEquals, t.zks, Commentf("dsn=%v", t.dsn)) - c.Assert(oracle, Equals, t.oracle, Commentf("dsn=%v", t.dsn)) + c.Assert(oracle, Equals, t.tso, Commentf("dsn=%v", t.dsn)) c.Assert(table, Equals, t.table, Commentf("dsn=%v", t.dsn)) } else { c.Assert(err, NotNil, Commentf("dsn=%v", t.dsn)) diff --git a/store/hbase/kv.go b/store/hbase/kv.go index 65153aa0db..6ceaff3050 100644 --- a/store/hbase/kv.go +++ b/store/hbase/kv.go @@ -69,12 +69,11 @@ func init() { } type hbaseStore struct { - mu sync.Mutex - uuid string - storeName string - oracleAddr string - oracle oracle.Oracle - conns []hbase.HBaseClient + mu sync.Mutex + uuid string + storeName string + oracle oracle.Oracle + conns []hbase.HBaseClient } func (s *hbaseStore) getHBaseClient() hbase.HBaseClient { @@ -139,26 +138,34 @@ func (s *hbaseStore) CurrentVersion() (kv.Version, error) { type Driver struct { } +const ( + tsoTypeLocal = "local" + tsoTypeZK = "zk" + + tsoZKPath = "/zk/tso" +) + // Open opens or creates an HBase storage with given path. // -// The format of path should be 'hbase://zk1,zk2,zk3/table[?tso=host:port]'. +// The format of path should be 'hbase://zk1,zk2,zk3/table[?tso=local|zk]'. // If tso is not provided, it will use a local oracle instead. (for test only) func (d Driver) Open(path string) (kv.Storage, error) { mc.mu.Lock() defer mc.mu.Unlock() - zks, oracleAddr, tableName, err := parsePath(path) + zks, tso, tableName, err := parsePath(path) if err != nil { return nil, errors.Trace(err) } + if tso != tsoTypeLocal && tso != tsoTypeZK { + return nil, errors.Trace(ErrInvalidDSN) + } uuid := fmt.Sprintf("hbase-%v-%v", zks, tableName) + if tso == tsoTypeLocal { + log.Warnf("hbase: store(%s) is using local oracle(for test only)", uuid) + } if store, ok := mc.cache[uuid]; ok { - if oracleAddr != store.oracleAddr { - err = errors.Errorf("hbase: store(%s) is opened with a different tso, old: %v, new: %v", uuid, store.oracleAddr, oracleAddr) - log.Warn(errors.ErrorStack(err)) - return nil, err - } return store, nil } @@ -167,7 +174,7 @@ func (d Driver) Open(path string) (kv.Storage, error) { conns := make([]hbase.HBaseClient, 0, hbaseConnPoolSize) for i := 0; i < hbaseConnPoolSize; i++ { var c hbase.HBaseClient - c, err = hbase.NewClient(zks, "/hbase") + c, err = hbase.NewClient(strings.Split(zks, ","), "/hbase") if err != nil { return nil, errors.Trace(err) } @@ -193,37 +200,39 @@ func (d Driver) Open(path string) (kv.Storage, error) { } var ora oracle.Oracle - if len(oracleAddr) == 0 { - log.Warnf("hbase: store(%s) is using local oracle(for test only)", uuid) + switch tso { + case tsoTypeLocal: ora = oracles.NewLocalOracle() - } else { - ora = oracles.NewRemoteOracle(oracleAddr) + case tsoTypeZK: + ora = oracles.NewRemoteOracle(zks, tsoZKPath) } s := &hbaseStore{ - uuid: uuid, - storeName: tableName, - oracleAddr: oracleAddr, - oracle: ora, - conns: conns, + uuid: uuid, + storeName: tableName, + oracle: ora, + conns: conns, } mc.cache[uuid] = s return s, nil } -func parsePath(path string) (zks []string, oracleAddr, tableName string, err error) { +func parsePath(path string) (zks, tso, tableName string, err error) { u, err := url.Parse(path) if err != nil { - return nil, "", "", errors.Trace(err) + return "", "", "", errors.Trace(err) } if strings.ToLower(u.Scheme) != "hbase" { - return nil, "", "", errors.Trace(ErrInvalidDSN) + return "", "", "", errors.Trace(ErrInvalidDSN) } p, tableName := filepath.Split(u.Path) if p != "/" { - return nil, "", "", errors.Trace(ErrInvalidDSN) + return "", "", "", errors.Trace(ErrInvalidDSN) } - zks = strings.Split(u.Host, ",") - oracleAddr = u.Query().Get("tso") - return zks, oracleAddr, tableName, nil + zks = u.Host + tso = u.Query().Get("tso") + if tso == "" { + tso = tsoTypeLocal + } + return zks, tso, tableName, nil } diff --git a/table/tables/tables.go b/table/tables/tables.go index eaaff24281..ecf13dca19 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -413,28 +413,10 @@ func (t *Table) AddRecord(ctx context.Context, r []interface{}) (recordID int64, bs := kv.NewBufferStore(txn) defer bs.Release() - for _, v := range t.indices { - if v == nil || v.State == model.StateDeleteOnly || v.State == model.StateDeleteReorganization { - // if index is in delete only or delete reorganization state, we can't add it. - continue - } - colVals, _ := v.FetchValues(r) - if err = v.X.Create(bs, colVals, recordID); err != nil { - if terror.ErrorEqual(err, kv.ErrKeyExists) { - // Get the duplicate row handle - // For insert on duplicate syntax, we should update the row - iter, _, err1 := v.X.Seek(bs, colVals) - if err1 != nil { - return 0, errors.Trace(err1) - } - _, h, err1 := iter.Next() - if err1 != nil { - return 0, errors.Trace(err1) - } - return h, errors.Trace(err) - } - return 0, errors.Trace(err) - } + // Insert new entries into indices. + h, err := t.addIndices(ctx, recordID, r, bs) + if err != nil { + return h, errors.Trace(err) } if err = t.LockRow(ctx, recordID); err != nil { @@ -446,7 +428,6 @@ func (t *Table) AddRecord(ctx context.Context, r []interface{}) (recordID int64, if col.IsPKHandleColumn(t.meta) { continue } - var value interface{} if col.State == model.StateWriteOnly || col.State == model.StateWriteReorganization { // if col is in write only or write reorganization state, we must add it with its default value. @@ -477,6 +458,82 @@ func (t *Table) AddRecord(ctx context.Context, r []interface{}) (recordID int64, return recordID, nil } +// Generate index content string representation. +func (t *Table) genIndexKeyStr(colVals []interface{}) (string, error) { + // Pass pre-composed error to txn. + strVals := make([]string, 0, len(colVals)) + for _, cv := range colVals { + cvs := "NULL" + var err error + if cv != nil { + cvs, err = types.ToString(cv) + if err != nil { + return "", errors.Trace(err) + } + } + strVals = append(strVals, cvs) + } + return strings.Join(strVals, "-"), nil +} + +// Add data into indices. +func (t *Table) addIndices(ctx context.Context, recordID int64, r []interface{}, bs *kv.BufferStore) (int64, error) { + txn, err := ctx.GetTxn(false) + if err != nil { + return 0, errors.Trace(err) + } + // Clean up lazy check error environment + defer txn.DelOption(kv.PresumeKeyNotExistsError) + if t.meta.PKIsHandle { + // Check key exists. + recordKey := t.RecordKey(recordID, nil) + e := kv.ErrKeyExists.Gen("Duplicate entry '%d' for key 'PRIMARY'", recordID) + txn.SetOption(kv.PresumeKeyNotExistsError, e) + _, err = txn.Get(recordKey) + if err == nil { + return recordID, errors.Trace(e) + } else if !terror.ErrorEqual(err, kv.ErrNotExist) { + return 0, errors.Trace(err) + } + txn.DelOption(kv.PresumeKeyNotExistsError) + } + + for _, v := range t.indices { + if v == nil || v.State == model.StateDeleteOnly || v.State == model.StateDeleteReorganization { + // if index is in delete only or delete reorganization state, we can't add it. + continue + } + colVals, _ := v.FetchValues(r) + var dupKeyErr error + if v.Unique || v.Primary { + entryKey, err1 := t.genIndexKeyStr(colVals) + if err1 != nil { + return 0, errors.Trace(err1) + } + dupKeyErr = kv.ErrKeyExists.Gen("Duplicate entry '%s' for key '%s'", entryKey, v.Name) + txn.SetOption(kv.PresumeKeyNotExistsError, dupKeyErr) + } + if err = v.X.Create(bs, colVals, recordID); err != nil { + if terror.ErrorEqual(err, kv.ErrKeyExists) { + // Get the duplicate row handle + // For insert on duplicate syntax, we should update the row + iter, _, err1 := v.X.Seek(bs, colVals) + if err1 != nil { + return 0, errors.Trace(err1) + } + _, h, err1 := iter.Next() + if err1 != nil { + return 0, errors.Trace(err1) + } + return h, errors.Trace(dupKeyErr) + } + return 0, errors.Trace(err) + } + txn.DelOption(kv.PresumeKeyNotExistsError) + } + return 0, nil +} + // EncodeValue implements table.Table EncodeValue interface. func (t *Table) EncodeValue(raw interface{}) ([]byte, error) { v, err := t.flatten(raw) diff --git a/util/testkit/testkit.go b/util/testkit/testkit.go index c08fa1215d..e7335d6a00 100644 --- a/util/testkit/testkit.go +++ b/util/testkit/testkit.go @@ -1,7 +1,22 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + package testkit import ( "fmt" + "strings" + "github.com/pingcap/check" "github.com/pingcap/tidb" "github.com/pingcap/tidb/kv" @@ -84,3 +99,19 @@ func (res *Result) Check(expected [][]interface{}) { need := fmt.Sprintf("%v", expected) res.c.Assert(got, check.Equals, need, res.comment) } + +// Rows is a convenient function to wrap args to a slice of []interface. +// The arg represents a row, split by white space, only applicable for +// values that have no white spaces. +func Rows(args ...string) [][]interface{} { + rows := make([][]interface{}, len(args)) + for i, v := range args { + strs := strings.Split(v, " ") + row := make([]interface{}, len(strs)) + for j, s := range strs { + row[j] = s + } + rows[i] = row + } + return rows +}