Merge remote-tracking branch 'origin/master' into mvcc-gc

This commit is contained in:
dongxu
2015-10-17 15:05:03 +08:00
70 changed files with 5267 additions and 2499 deletions

81
ast/ast.go Normal file
View File

@ -0,0 +1,81 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
// Package ast is the abstract syntax tree parsed from a SQL statement by parser.
// It can be analysed and transformed by optimizer.
package ast
import (
"github.com/pingcap/tidb/util/types"
)
// Node is the basic element of the AST.
// Interfaces embed Node should have 'Node' name suffix.
type Node interface {
// Accept accepts Visitor to visit itself.
// The returned node should replace original node.
// ok returns false to stop visiting.
Accept(v Visitor) (node Node, ok bool)
// Text returns the original text of the element.
Text() string
// SetText sets original text to the Node.
SetText(text string)
}
// ExprNode is a node that can be evaluated.
// Name of implementations should have 'Expr' suffix.
type ExprNode interface {
// Node is embeded in ExprNode.
Node
// IsStatic means it can be evaluated independently.
IsStatic() bool
// SetType sets evaluation type to the expression.
SetType(tp *types.FieldType)
// GetType gets the evaluation type of the expression.
GetType() *types.FieldType
}
// FuncNode represents function call expression node.
type FuncNode interface {
ExprNode
functionExpression()
}
// StmtNode represents statement node.
// Name of implementations should have 'Stmt' suffix.
type StmtNode interface {
Node
statement()
}
// DDLNode represents DDL statement node.
type DDLNode interface {
StmtNode
ddlStatement()
}
// DMLNode represents DML statement node.
type DMLNode interface {
StmtNode
dmlStatement()
}
// Visitor visits a Node.
type Visitor interface {
// VisitEnter is called before children nodes is visited.
// ok returns false to stop visiting.
Enter(n Node) (ok bool)
// VisitLeave is called after children nodes has been visited.
// ok returns false to stop visiting.
Leave(n Node) (node Node, ok bool)
}

88
ast/base.go Normal file
View File

@ -0,0 +1,88 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package ast
import "github.com/pingcap/tidb/util/types"
// node is the struct implements node interface except for Accept method.
// Node implementations should embed it in.
type node struct {
text string
}
// SetText implements Node interface.
func (n *node) SetText(text string) {
n.text = text
}
// Text implements Node interface.
func (n *node) Text() string {
return n.text
}
// stmtNode implements StmtNode interface.
// Statement implementations should embed it in.
type stmtNode struct {
node
}
// statement implements StmtNode interface.
func (sn *stmtNode) statement() {}
// ddlNode implements DDLNode interface.
// DDL implementations should embed it in.
type ddlNode struct {
stmtNode
}
// ddlStatement implements DDLNode interface.
func (dn *ddlNode) ddlStatement() {}
// dmlNode is the struct implements DMLNode interface.
// DML implementations should embed it in.
type dmlNode struct {
stmtNode
}
// dmlStatement implements DMLNode interface.
func (dn *dmlNode) dmlStatement() {}
// expressionNode is the struct implements Expression interface.
// Expression implementations should embed it in.
type exprNode struct {
node
tp *types.FieldType
}
// IsStatic implements Expression interface.
func (en *exprNode) IsStatic() bool {
return false
}
// SetType implements Expression interface.
func (en *exprNode) SetType(tp *types.FieldType) {
en.tp = tp
}
// GetType implements Expression interface.
func (en *exprNode) GetType() *types.FieldType {
return en.tp
}
type funcNode struct {
exprNode
}
// FunctionExpression implements FounctionNode interface.
func (fn *funcNode) functionExpression() {}

525
ast/ddl.go Normal file
View File

@ -0,0 +1,525 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package ast
import (
"github.com/pingcap/tidb/util/types"
)
var (
_ DDLNode = &CreateDatabaseStmt{}
_ DDLNode = &DropDatabaseStmt{}
_ DDLNode = &CreateTableStmt{}
_ DDLNode = &DropTableStmt{}
_ DDLNode = &CreateIndexStmt{}
_ DDLNode = &DropTableStmt{}
_ DDLNode = &AlterTableStmt{}
_ DDLNode = &TruncateTableStmt{}
_ Node = &IndexColName{}
_ Node = &ReferenceDef{}
_ Node = &ColumnOption{}
_ Node = &Constraint{}
_ Node = &ColumnDef{}
_ Node = &ColumnPosition{}
_ Node = &AlterTableSpec{}
)
// CharsetOpt is used for parsing charset option from SQL.
type CharsetOpt struct {
Chs string
Col string
}
// CreateDatabaseStmt is a statement to create a database.
// See: https://dev.mysql.com/doc/refman/5.7/en/create-database.html
type CreateDatabaseStmt struct {
ddlNode
IfNotExists bool
Name string
Opt *CharsetOpt
}
// Accept implements Node Accept interface.
func (cd *CreateDatabaseStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(cd) {
return cd, false
}
return v.Leave(cd)
}
// DropDatabaseStmt is a statement to drop a database and all tables in the database.
// See: https://dev.mysql.com/doc/refman/5.7/en/drop-database.html
type DropDatabaseStmt struct {
ddlNode
IfExists bool
Name string
}
// Accept implements Node Accept interface.
func (dd *DropDatabaseStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(dd) {
return dd, false
}
return v.Leave(dd)
}
// IndexColName is used for parsing index column name from SQL.
type IndexColName struct {
node
Column *ColumnRefExpr
Length int
}
// Accept implements Node Accept interface.
func (ic *IndexColName) Accept(v Visitor) (Node, bool) {
if !v.Enter(ic) {
return ic, false
}
node, ok := ic.Column.Accept(v)
if !ok {
return ic, false
}
ic.Column = node.(*ColumnRefExpr)
return v.Leave(ic)
}
// ReferenceDef is used for parsing foreign key reference option from SQL.
// See: http://dev.mysql.com/doc/refman/5.7/en/create-table-foreign-keys.html
type ReferenceDef struct {
node
Table *TableRef
IndexColNames []*IndexColName
}
// Accept implements Node Accept interface.
func (rd *ReferenceDef) Accept(v Visitor) (Node, bool) {
if !v.Enter(rd) {
return rd, false
}
node, ok := rd.Table.Accept(v)
if !ok {
return rd, false
}
rd.Table = node.(*TableRef)
for i, val := range rd.IndexColNames {
node, ok = val.Accept(v)
if !ok {
return rd, false
}
rd.IndexColNames[i] = node.(*IndexColName)
}
return v.Leave(rd)
}
// ColumnOptionType is the type for ColumnOption.
type ColumnOptionType int
// ColumnOption types.
const (
ColumnOptionNoOption ColumnOptionType = iota
ColumnOptionPrimaryKey
ColumnOptionNotNull
ColumnOptionAutoIncrement
ColumnOptionDefaultValue
ColumnOptionUniq
ColumnOptionIndex
ColumnOptionUniqIndex
ColumnOptionKey
ColumnOptionUniqKey
ColumnOptionNull
ColumnOptionOnUpdate // For Timestamp and Datetime only.
ColumnOptionFulltext
ColumnOptionComment
)
// ColumnOption is used for parsing column constraint info from SQL.
type ColumnOption struct {
node
Tp ColumnOptionType
// The value For Default or On Update.
Val ExprNode
}
// Accept implements Node Accept interface.
func (co *ColumnOption) Accept(v Visitor) (Node, bool) {
if !v.Enter(co) {
return co, false
}
if co.Val != nil {
node, ok := co.Val.Accept(v)
if !ok {
return co, false
}
co.Val = node.(ExprNode)
}
return v.Leave(co)
}
// ConstraintType is the type for Constraint.
type ConstraintType int
// ConstraintTypes
const (
ConstraintNoConstraint ConstraintType = iota
ConstraintPrimaryKey
ConstraintKey
ConstraintIndex
ConstraintUniq
ConstraintUniqKey
ConstraintUniqIndex
ConstraintForeignKey
)
// Constraint is constraint for table definition.
type Constraint struct {
node
Tp ConstraintType
Name string
// Used for PRIMARY KEY, UNIQUE, ......
Keys []*IndexColName
// Used for foreign key.
Refer *ReferenceDef
}
// Accept implements Node Accept interface.
func (tc *Constraint) Accept(v Visitor) (Node, bool) {
if !v.Enter(tc) {
return tc, false
}
for i, val := range tc.Keys {
node, ok := val.Accept(v)
if !ok {
return tc, false
}
tc.Keys[i] = node.(*IndexColName)
}
if tc.Refer != nil {
node, ok := tc.Refer.Accept(v)
if !ok {
return tc, false
}
tc.Refer = node.(*ReferenceDef)
}
return v.Leave(tc)
}
// ColumnDef is used for parsing column definition from SQL.
type ColumnDef struct {
node
Name string
Tp *types.FieldType
Options []*ColumnOption
}
// Accept implements Node Accept interface.
func (cd *ColumnDef) Accept(v Visitor) (Node, bool) {
if !v.Enter(cd) {
return cd, false
}
for i, val := range cd.Options {
node, ok := val.Accept(v)
if !ok {
return cd, false
}
cd.Options[i] = node.(*ColumnOption)
}
return v.Leave(cd)
}
// CreateTableStmt is a statement to create a table.
// See: https://dev.mysql.com/doc/refman/5.7/en/create-table.html
type CreateTableStmt struct {
ddlNode
IfNotExists bool
Table *TableRef
Cols []*ColumnDef
Constraints []*Constraint
Options []*TableOption
}
// Accept implements Node Accept interface.
func (ct *CreateTableStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(ct) {
return ct, false
}
node, ok := ct.Table.Accept(v)
if !ok {
return ct, false
}
ct.Table = node.(*TableRef)
for i, val := range ct.Cols {
node, ok = val.Accept(v)
if !ok {
return ct, false
}
ct.Cols[i] = node.(*ColumnDef)
}
for i, val := range ct.Constraints {
node, ok = val.Accept(v)
if !ok {
return ct, false
}
ct.Constraints[i] = node.(*Constraint)
}
return v.Leave(ct)
}
// DropTableStmt is a statement to drop one or more tables.
// See: https://dev.mysql.com/doc/refman/5.7/en/drop-table.html
type DropTableStmt struct {
ddlNode
IfExists bool
TableRefs []*TableRef
}
// Accept implements Node Accept interface.
func (dt *DropTableStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(dt) {
return dt, false
}
for i, val := range dt.TableRefs {
node, ok := val.Accept(v)
if !ok {
return dt, false
}
dt.TableRefs[i] = node.(*TableRef)
}
return v.Leave(dt)
}
// CreateIndexStmt is a statement to create an index.
// See: https://dev.mysql.com/doc/refman/5.7/en/create-index.html
type CreateIndexStmt struct {
ddlNode
IndexName string
Table *TableRef
Unique bool
IndexColNames []*IndexColName
}
// Accept implements Node Accept interface.
func (ci *CreateIndexStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(ci) {
return ci, false
}
node, ok := ci.Table.Accept(v)
if !ok {
return ci, false
}
ci.Table = node.(*TableRef)
for i, val := range ci.IndexColNames {
node, ok = val.Accept(v)
if !ok {
return ci, false
}
ci.IndexColNames[i] = node.(*IndexColName)
}
return v.Leave(ci)
}
// DropIndexStmt is a statement to drop the index.
// See: https://dev.mysql.com/doc/refman/5.7/en/drop-index.html
type DropIndexStmt struct {
ddlNode
IfExists bool
IndexName string
}
// Accept implements Node Accept interface.
func (di *DropIndexStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(di) {
return di, false
}
return v.Leave(di)
}
// TableOptionType is the type for TableOption
type TableOptionType int
// TableOption types.
const (
TableOptionNone TableOptionType = iota
TableOptionEngine
TableOptionCharset
TableOptionCollate
TableOptionAutoIncrement
TableOptionComment
TableOptionAvgRowLength
TableOptionCheckSum
TableOptionCompression
TableOptionConnection
TableOptionPassword
TableOptionKeyBlockSize
TableOptionMaxRows
TableOptionMinRows
)
// TableOption is used for parsing table option from SQL.
type TableOption struct {
Tp TableOptionType
StrValue string
UintValue uint64
}
// ColumnPositionType is the type for ColumnPosition.
type ColumnPositionType int
// ColumnPosition Types
const (
ColumnPositionNone ColumnPositionType = iota
ColumnPositionFirst
ColumnPositionAfter
)
// ColumnPosition represent the position of the newly added column
type ColumnPosition struct {
node
// ColumnPositionNone | ColumnPositionFirst | ColumnPositionAfter
Tp ColumnPositionType
// RelativeColumn is the column the newly added column after if type is ColumnPositionAfter
RelativeColumn *ColumnRefExpr
}
// Accept implements Node Accept interface.
func (cp *ColumnPosition) Accept(v Visitor) (Node, bool) {
if !v.Enter(cp) {
return cp, false
}
node, ok := cp.RelativeColumn.Accept(v)
if !ok {
return cp, false
}
cp.RelativeColumn = node.(*ColumnRefExpr)
return v.Leave(cp)
}
// AlterTableType is the type for AlterTableSpec.
type AlterTableType int
// AlterTable types.
const (
AlterTableOption AlterTableType = iota + 1
AlterTableAddColumn
AlterTableAddConstraint
AlterTableDropColumn
AlterTableDropPrimaryKey
AlterTableDropIndex
AlterTableDropForeignKey
// TODO: Add more actions
)
// AlterTableSpec represents alter table specification.
type AlterTableSpec struct {
node
Tp AlterTableType
Name string
Constraint *Constraint
TableOpts []*TableOption
Column *ColumnDef
Position *ColumnPosition
}
// Accept implements Node Accept interface.
func (as *AlterTableSpec) Accept(v Visitor) (Node, bool) {
if !v.Enter(as) {
return as, false
}
if as.Constraint != nil {
node, ok := as.Constraint.Accept(v)
if !ok {
return as, false
}
as.Constraint = node.(*Constraint)
}
if as.Column != nil {
node, ok := as.Column.Accept(v)
if !ok {
return as, false
}
as.Column = node.(*ColumnDef)
}
if as.Position != nil {
node, ok := as.Position.Accept(v)
if !ok {
return as, false
}
as.Position = node.(*ColumnPosition)
}
return v.Leave(as)
}
// AlterTableStmt is a statement to change the structure of a table.
// See: https://dev.mysql.com/doc/refman/5.7/en/alter-table.html
type AlterTableStmt struct {
ddlNode
Table *TableRef
Specs []*AlterTableSpec
}
// Accept implements Node Accept interface.
func (at *AlterTableStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(at) {
return at, false
}
node, ok := at.Table.Accept(v)
if !ok {
return at, false
}
at.Table = node.(*TableRef)
for i, val := range at.Specs {
node, ok = val.Accept(v)
if !ok {
return at, false
}
at.Specs[i] = node.(*AlterTableSpec)
}
return v.Leave(at)
}
// TruncateTableStmt is a statement to empty a table completely.
// See: https://dev.mysql.com/doc/refman/5.7/en/truncate-table.html
type TruncateTableStmt struct {
ddlNode
Table *TableRef
}
// Accept implements Node Accept interface.
func (ts *TruncateTableStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(ts) {
return ts, false
}
node, ok := ts.Table.Accept(v)
if !ok {
return ts, false
}
ts.Table = node.(*TableRef)
return v.Leave(ts)
}

409
ast/dml.go Normal file
View File

@ -0,0 +1,409 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package ast
import (
"github.com/pingcap/tidb/model"
)
var (
_ DMLNode = &InsertStmt{}
_ DMLNode = &DeleteStmt{}
_ DMLNode = &UpdateStmt{}
_ DMLNode = &SelectStmt{}
_ Node = &Join{}
_ Node = &Union{}
_ Node = &TableRef{}
_ Node = &TableSource{}
_ Node = &Assignment{}
)
// JoinType is join type, including cross/left/right/full.
type JoinType int
const (
// CrossJoin is cross join type.
CrossJoin JoinType = iota + 1
// LeftJoin is left Join type.
LeftJoin
// RightJoin is right Join type.
RightJoin
)
// Join represents table join.
type Join struct {
node
// Left table can be TableSource or JoinNode.
Left Node
// Right table can be TableSource or JoinNode or nil.
Right Node
// Tp represents join type.
Tp JoinType
}
// Accept implements Node Accept interface.
func (j *Join) Accept(v Visitor) (Node, bool) {
if !v.Enter(j) {
return j, false
}
node, ok := j.Left.Accept(v)
if !ok {
return j, false
}
j.Left = node
if j.Right != nil {
node, ok = j.Right.Accept(v)
if !ok {
return j, false
}
j.Right = node
}
return v.Leave(j)
}
// TableRef represents a reference to actual table.
type TableRef struct {
node
Schema model.CIStr
Name model.CIStr
}
// Accept implements Node Accept interface.
func (tr *TableRef) Accept(v Visitor) (Node, bool) {
if !v.Enter(tr) {
return tr, false
}
return v.Leave(tr)
}
// TableSource represents table source with a name.
type TableSource struct {
node
// Source is the source of the data, can be a TableRef,
// a SubQuery, or a JoinNode.
Source Node
// Name is the alias name of the table source.
Name string
}
// Accept implements Node Accept interface.
func (ts *TableSource) Accept(v Visitor) (Node, bool) {
if !v.Enter(ts) {
return ts, false
}
node, ok := ts.Source.Accept(v)
if !ok {
return ts, false
}
ts.Source = node
return v.Leave(ts)
}
// Union represents union select statement.
type Union struct {
node
Select *SelectStmt
}
// Accept implements Node Accept interface.
func (u *Union) Accept(v Visitor) (Node, bool) {
if !v.Enter(u) {
return u, false
}
node, ok := u.Select.Accept(v)
if !ok {
return u, false
}
u.Select = node.(*SelectStmt)
return v.Leave(u)
}
// SelectLockType is the lock type for SelectStmt.
type SelectLockType int
// Select lock types.
const (
SelectLockNone SelectLockType = iota
SelectLockForUpdate
SelectLockInShareMode
)
// SelectStmt represents the select query node.
type SelectStmt struct {
dmlNode
// Distinct represents if the select has distinct option.
Distinct bool
// Fields is the select expression list.
Fields []ExprNode
// From is the from clause of the query.
From *Join
// Where is the where clause in select statement.
Where ExprNode
// GroupBy is the group by expression list.
GroupBy []ExprNode
// Having is the having condition.
Having ExprNode
// OrderBy is the odering expression list.
OrderBy []ExprNode
// Offset is the offset value.
Offset int
// Limit is the limit value.
Limit int
// Lock is the lock type
LockTp SelectLockType
// Unions is the union select statement.
Unions []*Union
}
// Accept implements Node Accept interface.
func (sn *SelectStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(sn) {
return sn, false
}
for i, val := range sn.Fields {
node, ok := val.Accept(v)
if !ok {
return sn, false
}
sn.Fields[i] = node.(ExprNode)
}
if sn.From != nil {
node, ok := sn.From.Accept(v)
if !ok {
return sn, false
}
sn.From = node.(*Join)
}
if sn.Where != nil {
node, ok := sn.Where.Accept(v)
if !ok {
return sn, false
}
sn.Where = node.(ExprNode)
}
for i, val := range sn.GroupBy {
node, ok := val.Accept(v)
if !ok {
return sn, false
}
sn.GroupBy[i] = node.(ExprNode)
}
if sn.Having != nil {
node, ok := sn.Having.Accept(v)
if !ok {
return sn, false
}
sn.Having = node.(ExprNode)
}
for i, val := range sn.OrderBy {
node, ok := val.Accept(v)
if !ok {
return sn, false
}
sn.OrderBy[i] = node.(ExprNode)
}
for i, val := range sn.Unions {
node, ok := val.Accept(v)
if !ok {
return sn, false
}
sn.Unions[i] = node.(*Union)
}
return v.Leave(sn)
}
// Assignment is the expression for assignment, like a = 1.
type Assignment struct {
node
// Column is the column reference to be assigned.
Column *ColumnRefExpr
// Expr is the expression assigning to ColName.
Expr ExprNode
}
// Accept implements Node Accept interface.
func (as *Assignment) Accept(v Visitor) (Node, bool) {
if !v.Enter(as) {
return as, false
}
node, ok := as.Column.Accept(v)
if !ok {
return as, false
}
as.Column = node.(*ColumnRefExpr)
node, ok = as.Expr.Accept(v)
if !ok {
return as, false
}
as.Expr = node.(ExprNode)
return v.Leave(as)
}
// InsertStmt is a statement to insert new rows into an existing table.
// See: https://dev.mysql.com/doc/refman/5.7/en/insert.html
type InsertStmt struct {
dmlNode
Columns []*ColumnRefExpr
Lists [][]ExprNode
Table *TableRef
Setlist []*Assignment
Priority int
OnDuplicate []*Assignment
}
// Accept implements Node Accept interface.
func (in *InsertStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(in) {
return in, false
}
for i, val := range in.Columns {
node, ok := val.Accept(v)
if !ok {
return in, false
}
in.Columns[i] = node.(*ColumnRefExpr)
}
for i, list := range in.Lists {
for j, val := range list {
node, ok := val.Accept(v)
if !ok {
return in, false
}
in.Lists[i][j] = node.(ExprNode)
}
}
for i, val := range in.Setlist {
node, ok := val.Accept(v)
if !ok {
return in, false
}
in.Setlist[i] = node.(*Assignment)
}
for i, val := range in.OnDuplicate {
node, ok := val.Accept(v)
if !ok {
return in, false
}
in.OnDuplicate[i] = node.(*Assignment)
}
return v.Leave(in)
}
// DeleteStmt is a statement to delete rows from table.
// See: https://dev.mysql.com/doc/refman/5.7/en/delete.html
type DeleteStmt struct {
dmlNode
Tables []*TableRef
Where ExprNode
Order []ExprNode
Limit int
LowPriority bool
Ignore bool
Quick bool
MultiTable bool
BeforeFrom bool
}
// Accept implements Node Accept interface.
func (de *DeleteStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(de) {
return de, false
}
for i, val := range de.Tables {
node, ok := val.Accept(v)
if !ok {
return de, false
}
de.Tables[i] = node.(*TableRef)
}
if de.Where != nil {
node, ok := de.Where.Accept(v)
if !ok {
return de, false
}
de.Where = node.(ExprNode)
}
for i, val := range de.Order {
node, ok := val.Accept(v)
if !ok {
return de, false
}
de.Order[i] = node.(ExprNode)
}
return v.Leave(de)
}
// UpdateStmt is a statement to update columns of existing rows in tables with new values.
// See: https://dev.mysql.com/doc/refman/5.7/en/update.html
type UpdateStmt struct {
dmlNode
TableRefs *Join
List []*Assignment
Where ExprNode
Order []ExprNode
Limit int
LowPriority bool
Ignore bool
MultipleTable bool
}
// Accept implements Node Accept interface.
func (up *UpdateStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(up) {
return up, false
}
node, ok := up.TableRefs.Accept(v)
if !ok {
return up, false
}
up.TableRefs = node.(*Join)
for i, val := range up.List {
node, ok = val.Accept(v)
if !ok {
return up, false
}
up.List[i] = node.(*Assignment)
}
if up.Where != nil {
node, ok = up.Where.Accept(v)
if !ok {
return up, false
}
up.Where = node.(ExprNode)
}
for i, val := range up.Order {
node, ok = val.Accept(v)
if !ok {
return up, false
}
up.Order[i] = node.(ExprNode)
}
return v.Leave(up)
}

669
ast/expressions.go Normal file
View File

@ -0,0 +1,669 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package ast
import (
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/parser/opcode"
)
var (
_ ExprNode = &ValueExpr{}
_ ExprNode = &BetweenExpr{}
_ ExprNode = &BinaryOperationExpr{}
_ Node = &WhenClause{}
_ ExprNode = &CaseExpr{}
_ ExprNode = &SubqueryExpr{}
_ ExprNode = &CompareSubqueryExpr{}
_ ExprNode = &ColumnRefExpr{}
_ ExprNode = &DefaultExpr{}
_ ExprNode = &ExistsSubqueryExpr{}
_ ExprNode = &PatternInExpr{}
_ ExprNode = &IsNullExpr{}
_ ExprNode = &IsTruthExpr{}
_ ExprNode = &PatternLikeExpr{}
_ ExprNode = &ParamMarkerExpr{}
_ ExprNode = &ParenthesesExpr{}
_ ExprNode = &PositionExpr{}
_ ExprNode = &PatternRegexpExpr{}
_ ExprNode = &RowExpr{}
_ ExprNode = &UnaryOperationExpr{}
_ ExprNode = &ValuesExpr{}
_ ExprNode = &VariableExpr{}
)
// ValueExpr is the simple value expression.
type ValueExpr struct {
exprNode
// Val is the literal value.
Val interface{}
}
// IsStatic implements ExprNode interface.
func (val *ValueExpr) IsStatic() bool {
return true
}
// Accept implements Node interface.
func (val *ValueExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(val) {
return val, false
}
return v.Leave(val)
}
// BetweenExpr is for "between and" or "not between and" expression.
type BetweenExpr struct {
exprNode
// Expr is the expression to be checked.
Expr ExprNode
// Left is the expression for minimal value in the range.
Left ExprNode
// Right is the expression for maximum value in the range.
Right ExprNode
// Not is true, the expression is "not between and".
Not bool
}
// Accept implements Node interface.
func (b *BetweenExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(b) {
return b, false
}
node, ok := b.Expr.Accept(v)
if !ok {
return b, false
}
b.Expr = node.(ExprNode)
node, ok = b.Left.Accept(v)
if !ok {
return b, false
}
b.Left = node.(ExprNode)
node, ok = b.Right.Accept(v)
if !ok {
return b, false
}
b.Right = node.(ExprNode)
return v.Leave(b)
}
// IsStatic implements the ExprNode IsStatic interface.
func (b *BetweenExpr) IsStatic() bool {
return b.Expr.IsStatic() && b.Left.IsStatic() && b.Right.IsStatic()
}
// BinaryOperationExpr is for binary operation like 1 + 1, 1 - 1, etc.
type BinaryOperationExpr struct {
exprNode
// Op is the operator code for BinaryOperation.
Op opcode.Op
// L is the left expression in BinaryOperation.
L ExprNode
// R is the right expression in BinaryOperation.
R ExprNode
}
// Accept implements Node interface.
func (o *BinaryOperationExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(o) {
return o, false
}
node, ok := o.L.Accept(v)
if !ok {
return o, false
}
o.L = node.(ExprNode)
node, ok = o.R.Accept(v)
if !ok {
return o, false
}
o.R = node.(ExprNode)
return v.Leave(o)
}
// IsStatic implements the ExprNode IsStatic interface.
func (o *BinaryOperationExpr) IsStatic() bool {
return o.L.IsStatic() && o.R.IsStatic()
}
// WhenClause is the when clause in Case expression for "when condition then result".
type WhenClause struct {
node
// Expr is the condition expression in WhenClause.
Expr ExprNode
// Result is the result expression in WhenClause.
Result ExprNode
}
// Accept implements Node Accept interface.
func (w *WhenClause) Accept(v Visitor) (Node, bool) {
if !v.Enter(w) {
return w, false
}
node, ok := w.Expr.Accept(v)
if !ok {
return w, false
}
w.Expr = node.(ExprNode)
node, ok = w.Result.Accept(v)
if !ok {
return w, false
}
w.Result = node.(ExprNode)
return v.Leave(w)
}
// IsStatic implements the ExprNode IsStatic interface.
func (w *WhenClause) IsStatic() bool {
return w.Expr.IsStatic() && w.Result.IsStatic()
}
// CaseExpr is the case expression.
type CaseExpr struct {
exprNode
// Value is the compare value expression.
Value ExprNode
// WhenClauses is the condition check expression.
WhenClauses []*WhenClause
// ElseClause is the else result expression.
ElseClause ExprNode
}
// Accept implements Node Accept interface.
func (f *CaseExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(f) {
return f, false
}
if f.Value != nil {
node, ok := f.Value.Accept(v)
if !ok {
return f, false
}
f.Value = node.(ExprNode)
}
for i, val := range f.WhenClauses {
node, ok := val.Accept(v)
if !ok {
return f, false
}
f.WhenClauses[i] = node.(*WhenClause)
}
if f.ElseClause != nil {
node, ok := f.ElseClause.Accept(v)
if !ok {
return f, false
}
f.ElseClause = node.(ExprNode)
}
return v.Leave(f)
}
// IsStatic implements the ExprNode IsStatic interface.
func (f *CaseExpr) IsStatic() bool {
if f.Value != nil && !f.Value.IsStatic() {
return false
}
for _, w := range f.WhenClauses {
if !w.IsStatic() {
return false
}
}
if f.ElseClause != nil && !f.ElseClause.IsStatic() {
return false
}
return true
}
// SubqueryExpr represents a sub query.
type SubqueryExpr struct {
exprNode
// Query is the query SelectNode.
Query *SelectStmt
}
// Accept implements Node Accept interface.
func (sq *SubqueryExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(sq) {
return sq, false
}
node, ok := sq.Query.Accept(v)
if !ok {
return sq, false
}
sq.Query = node.(*SelectStmt)
return v.Leave(sq)
}
// CompareSubqueryExpr is the expression for "expr cmp (select ...)".
// See: https://dev.mysql.com/doc/refman/5.7/en/comparisons-using-subqueries.html
// See: https://dev.mysql.com/doc/refman/5.7/en/any-in-some-subqueries.html
// See: https://dev.mysql.com/doc/refman/5.7/en/all-subqueries.html
type CompareSubqueryExpr struct {
exprNode
// L is the left expression
L ExprNode
// Op is the comparison opcode.
Op opcode.Op
// R is the sub query for right expression.
R *SubqueryExpr
// All is true, we should compare all records in subquery.
All bool
}
// Accept implements Node Accept interface.
func (cs *CompareSubqueryExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(cs) {
return cs, false
}
node, ok := cs.L.Accept(v)
if !ok {
return cs, false
}
cs.L = node.(ExprNode)
node, ok = cs.R.Accept(v)
if !ok {
return cs, false
}
cs.R = node.(*SubqueryExpr)
return v.Leave(cs)
}
// ColumnRefExpr represents a column reference.
type ColumnRefExpr struct {
exprNode
// Name is the referenced column name.
Name model.CIStr
}
// Accept implements Node Accept interface.
func (cr *ColumnRefExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(cr) {
return cr, false
}
return v.Leave(cr)
}
// DefaultExpr is the default expression using default value for a column.
type DefaultExpr struct {
exprNode
// Name is the column name.
Name string
}
// Accept implements Node Accept interface.
func (d *DefaultExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(d) {
return d, false
}
return v.Leave(d)
}
// ExistsSubqueryExpr is the expression for "exists (select ...)".
// https://dev.mysql.com/doc/refman/5.7/en/exists-and-not-exists-subqueries.html
type ExistsSubqueryExpr struct {
exprNode
// Sel is the sub query.
Sel *SubqueryExpr
}
// Accept implements Node Accept interface.
func (es *ExistsSubqueryExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(es) {
return es, false
}
node, ok := es.Sel.Accept(v)
if !ok {
return es, false
}
es.Sel = node.(*SubqueryExpr)
return v.Leave(es)
}
// PatternInExpr is the expression for in operator, like "expr in (1, 2, 3)" or "expr in (select c from t)".
type PatternInExpr struct {
exprNode
// Expr is the value expression to be compared.
Expr ExprNode
// List is the list expression in compare list.
List []ExprNode
// Not is true, the expression is "not in".
Not bool
// Sel is the sub query.
Sel *SubqueryExpr
}
// Accept implements Node Accept interface.
func (pi *PatternInExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(pi) {
return pi, false
}
node, ok := pi.Expr.Accept(v)
if !ok {
return pi, false
}
pi.Expr = node.(ExprNode)
for i, val := range pi.List {
node, ok = val.Accept(v)
if !ok {
return pi, false
}
pi.List[i] = node.(ExprNode)
}
if pi.Sel != nil {
node, ok = pi.Sel.Accept(v)
if !ok {
return pi, false
}
pi.Sel = node.(*SubqueryExpr)
}
return v.Leave(pi)
}
// IsNullExpr is the expression for null check.
type IsNullExpr struct {
exprNode
// Expr is the expression to be checked.
Expr ExprNode
// Not is true, the expression is "is not null".
Not bool
}
// Accept implements Node Accept interface.
func (is *IsNullExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(is) {
return is, false
}
node, ok := is.Expr.Accept(v)
if !ok {
return is, false
}
is.Expr = node.(ExprNode)
return v.Leave(is)
}
// IsStatic implements the ExprNode IsStatic interface.
func (is *IsNullExpr) IsStatic() bool {
return is.Expr.IsStatic()
}
// IsTruthExpr is the expression for true/false check.
type IsTruthExpr struct {
exprNode
// Expr is the expression to be checked.
Expr ExprNode
// Not is true, the expression is "is not true/false".
Not bool
// True indicates checking true or false.
True int64
}
// Accept implements Node Accept interface.
func (is *IsTruthExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(is) {
return is, false
}
node, ok := is.Expr.Accept(v)
if !ok {
return is, false
}
is.Expr = node.(ExprNode)
return v.Leave(is)
}
// IsStatic implements the ExprNode IsStatic interface.
func (is *IsTruthExpr) IsStatic() bool {
return is.Expr.IsStatic()
}
// PatternLikeExpr is the expression for like operator, e.g, expr like "%123%"
type PatternLikeExpr struct {
exprNode
// Expr is the expression to be checked.
Expr ExprNode
// Pattern is the like expression.
Pattern ExprNode
// Not is true, the expression is "not like".
Not bool
}
// Accept implements Node Accept interface.
func (pl *PatternLikeExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(pl) {
return pl, false
}
node, ok := pl.Expr.Accept(v)
if !ok {
return pl, false
}
pl.Expr = node.(ExprNode)
node, ok = pl.Pattern.Accept(v)
if !ok {
return pl, false
}
pl.Pattern = node.(ExprNode)
return v.Leave(pl)
}
// IsStatic implements the ExprNode IsStatic interface.
func (pl *PatternLikeExpr) IsStatic() bool {
return pl.Expr.IsStatic() && pl.Pattern.IsStatic()
}
// ParamMarkerExpr expresion holds a place for another expression.
// Used in parsing prepare statement.
type ParamMarkerExpr struct {
exprNode
}
// Accept implements Node Accept interface.
func (pm *ParamMarkerExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(pm) {
return pm, false
}
return v.Leave(pm)
}
// ParenthesesExpr is the parentheses expression.
type ParenthesesExpr struct {
exprNode
// Expr is the expression in parentheses.
Expr ExprNode
}
// Accept implements Node Accept interface.
func (p *ParenthesesExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(p) {
return p, false
}
if p.Expr != nil {
node, ok := p.Expr.Accept(v)
if !ok {
return p, false
}
p.Expr = node.(ExprNode)
}
return v.Leave(p)
}
// IsStatic implements the ExprNode IsStatic interface.
func (p *ParenthesesExpr) IsStatic() bool {
return p.Expr.IsStatic()
}
// PositionExpr is the expression for order by and group by position.
// MySQL use position expression started from 1, it looks a little confused inner.
// maybe later we will use 0 at first.
type PositionExpr struct {
exprNode
// N is the position, started from 1 now.
N int
// Name is the corresponding field name if we want better format and explain instead of position.
Name string
}
// IsStatic implements the ExprNode IsStatic interface.
func (p *PositionExpr) IsStatic() bool {
return true
}
// Accept implements Node Accept interface.
func (p *PositionExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(p) {
return p, false
}
return v.Leave(p)
}
// PatternRegexpExpr is the pattern expression for pattern match.
type PatternRegexpExpr struct {
exprNode
// Expr is the expression to be checked.
Expr ExprNode
// Pattern is the expression for pattern.
Pattern ExprNode
// Not is true, the expression is "not rlike",
Not bool
}
// Accept implements Node Accept interface.
func (p *PatternRegexpExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(p) {
return p, false
}
node, ok := p.Expr.Accept(v)
if !ok {
return p, false
}
p.Expr = node.(ExprNode)
node, ok = p.Pattern.Accept(v)
if !ok {
return p, false
}
p.Pattern = node.(ExprNode)
return v.Leave(p)
}
// IsStatic implements the ExprNode IsStatic interface.
func (p *PatternRegexpExpr) IsStatic() bool {
return p.Expr.IsStatic() && p.Pattern.IsStatic()
}
// RowExpr is the expression for row constructor.
// See https://dev.mysql.com/doc/refman/5.7/en/row-subqueries.html
type RowExpr struct {
exprNode
Values []ExprNode
}
// Accept implements Node Accept interface.
func (r *RowExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(r) {
return r, false
}
for i, val := range r.Values {
node, ok := val.Accept(v)
if !ok {
return r, false
}
r.Values[i] = node.(ExprNode)
}
return v.Leave(r)
}
// IsStatic implements the ExprNode IsStatic interface.
func (r *RowExpr) IsStatic() bool {
for _, v := range r.Values {
if !v.IsStatic() {
return false
}
}
return true
}
// UnaryOperationExpr is the expression for unary operator.
type UnaryOperationExpr struct {
exprNode
// Op is the operator opcode.
Op opcode.Op
// V is the unary expression.
V ExprNode
}
// Accept implements Node Accept interface.
func (u *UnaryOperationExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(u) {
return u, false
}
node, ok := u.V.Accept(v)
if !ok {
return u, false
}
u.V = node.(ExprNode)
return v.Leave(u)
}
// IsStatic implements the ExprNode IsStatic interface.
func (u *UnaryOperationExpr) IsStatic() bool {
return u.V.IsStatic()
}
// ValuesExpr is the expression used in INSERT VALUES
type ValuesExpr struct {
exprNode
// model.CIStr is column name.
Column *ColumnRefExpr
}
// Accept implements Node Accept interface.
func (va *ValuesExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(va) {
return va, false
}
node, ok := va.Column.Accept(v)
if !ok {
return va, false
}
va.Column = node.(*ColumnRefExpr)
return v.Leave(va)
}
// VariableExpr is the expression for variable.
type VariableExpr struct {
exprNode
// Name is the variable name.
Name string
// IsGlobal indicates whether this variable is global.
IsGlobal bool
// IsSystem indicates whether this variable is a global variable in current session.
IsSystem bool
}
// Accept implements Node Accept interface.
func (va *VariableExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(va) {
return va, false
}
return v.Leave(va)
}

251
ast/functions.go Normal file
View File

@ -0,0 +1,251 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package ast
import (
"strings"
"github.com/pingcap/tidb/expression/builtin"
"github.com/pingcap/tidb/util/types"
)
var (
_ FuncNode = &FuncCallExpr{}
_ FuncNode = &FuncExtractExpr{}
_ FuncNode = &FuncConvertExpr{}
_ FuncNode = &FuncCastExpr{}
_ FuncNode = &FuncSubstringExpr{}
_ FuncNode = &FuncTrimExpr{}
)
// FuncCallExpr is for function expression.
type FuncCallExpr struct {
funcNode
// F is the function name.
F string
// Args is the function args.
Args []ExprNode
// Distinct only affetcts sum, avg, count, group_concat,
// so we can ignore it in other functions
Distinct bool
}
// Accept implements Node interface.
func (c *FuncCallExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(c) {
return c, false
}
for i, val := range c.Args {
node, ok := val.Accept(v)
if !ok {
return c, false
}
c.Args[i] = node.(ExprNode)
}
return v.Leave(c)
}
// IsStatic implements the ExprNode IsStatic interface.
func (c *FuncCallExpr) IsStatic() bool {
v := builtin.Funcs[strings.ToLower(c.F)]
if v.F == nil || !v.IsStatic {
return false
}
for _, v := range c.Args {
if !v.IsStatic() {
return false
}
}
return true
}
// FuncExtractExpr is for time extract function.
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_extract
type FuncExtractExpr struct {
funcNode
Unit string
Date ExprNode
}
// Accept implements Node Accept interface.
func (ex *FuncExtractExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(ex) {
return ex, false
}
node, ok := ex.Date.Accept(v)
if !ok {
return ex, false
}
ex.Date = node.(ExprNode)
return v.Leave(ex)
}
// IsStatic implements the ExprNode IsStatic interface.
func (ex *FuncExtractExpr) IsStatic() bool {
return ex.Date.IsStatic()
}
// FuncConvertExpr provides a way to convert data between different character sets.
// See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#function_convert
type FuncConvertExpr struct {
funcNode
// Expr is the expression to be converted.
Expr ExprNode
// Charset is the target character set to convert.
Charset string
}
// IsStatic implements the ExprNode IsStatic interface.
func (f *FuncConvertExpr) IsStatic() bool {
return f.Expr.IsStatic()
}
// Accept implements Node Accept interface.
func (f *FuncConvertExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(f) {
return f, false
}
node, ok := f.Expr.Accept(v)
if !ok {
return f, false
}
f.Expr = node.(ExprNode)
return v.Leave(f)
}
// castOperatopr is the operator type for cast function.
type castFunctionType int
// castFunction types
const (
CastFunction castFunctionType = iota + 1
ConvertFunction
BinaryOperator
)
// FuncCastExpr is the cast function converting value to another type, e.g, cast(expr AS signed).
// See https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html
type FuncCastExpr struct {
funcNode
// Expr is the expression to be converted.
Expr ExprNode
// Tp is the conversion type.
Tp *types.FieldType
// Cast, Convert and Binary share this struct.
FunctionType castFunctionType
}
// IsStatic implements the ExprNode IsStatic interface.
func (f *FuncCastExpr) IsStatic() bool {
return f.Expr.IsStatic()
}
// Accept implements Node Accept interface.
func (f *FuncCastExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(f) {
return f, false
}
node, ok := f.Expr.Accept(v)
if !ok {
return f, false
}
f.Expr = node.(ExprNode)
return v.Leave(f)
}
// FuncSubstringExpr returns the substring as specified.
// See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_substring
type FuncSubstringExpr struct {
funcNode
StrExpr ExprNode
Pos ExprNode
Len ExprNode
}
// Accept implements Node Accept interface.
func (sf *FuncSubstringExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(sf) {
return sf, false
}
node, ok := sf.StrExpr.Accept(v)
if !ok {
return sf, false
}
sf.StrExpr = node.(ExprNode)
node, ok = sf.Pos.Accept(v)
if !ok {
return sf, false
}
sf.Pos = node.(ExprNode)
node, ok = sf.Len.Accept(v)
if !ok {
return sf, false
}
sf.Len = node.(ExprNode)
return v.Leave(sf)
}
// IsStatic implements the ExprNode IsStatic interface.
func (sf *FuncSubstringExpr) IsStatic() bool {
return sf.StrExpr.IsStatic() && sf.Pos.IsStatic() && sf.Len.IsStatic()
}
type trimDirectionType int
const (
// TrimBothDefault trims from both direction by default.
TrimBothDefault trimDirectionType = iota
// TrimBoth trims from both direction with explicit notation.
TrimBoth
// TrimLeading trims from left.
TrimLeading
// TrimTrailing trims from right.
TrimTrailing
)
// FuncTrimExpr remove leading/trailing/both remstr.
// See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_trim
type FuncTrimExpr struct {
funcNode
Str ExprNode
RemStr ExprNode
Direction trimDirectionType
}
// Accept implements Node Accept interface.
func (tf *FuncTrimExpr) Accept(v Visitor) (Node, bool) {
if !v.Enter(tf) {
return tf, false
}
node, ok := tf.Str.Accept(v)
if !ok {
return tf, false
}
tf.Str = node.(ExprNode)
node, ok = tf.RemStr.Accept(v)
if !ok {
return tf, false
}
tf.RemStr = node.(ExprNode)
return v.Leave(tf)
}
// IsStatic implements the ExprNode IsStatic interface.
func (tf *FuncTrimExpr) IsStatic() bool {
return tf.Str.IsStatic() && tf.RemStr.IsStatic()
}

283
ast/misc.go Normal file
View File

@ -0,0 +1,283 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package ast
var (
_ StmtNode = &ExplainStmt{}
_ StmtNode = &PrepareStmt{}
_ StmtNode = &DeallocateStmt{}
_ StmtNode = &ExecuteStmt{}
_ StmtNode = &ShowStmt{}
_ StmtNode = &BeginStmt{}
_ StmtNode = &CommitStmt{}
_ StmtNode = &RollbackStmt{}
_ StmtNode = &UseStmt{}
_ StmtNode = &SetStmt{}
_ Node = &VariableAssignment{}
)
// AuthOption is used for parsing create use statement.
type AuthOption struct {
// AuthString/HashString can be empty, so we need to decide which one to use.
ByAuthString bool
AuthString string
HashString string
// TODO: support auth_plugin
}
// ExplainStmt is a statement to provide information about how is SQL statement executed
// or get columns information in a table.
// See: https://dev.mysql.com/doc/refman/5.7/en/explain.html
type ExplainStmt struct {
stmtNode
Stmt DMLNode
}
// Accept implements Node Accept interface.
func (es *ExplainStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(es) {
return es, false
}
node, ok := es.Stmt.Accept(v)
if !ok {
return es, false
}
es.Stmt = node.(DMLNode)
return v.Leave(es)
}
// PrepareStmt is a statement to prepares a SQL statement which contains placeholders,
// and it is executed with ExecuteStmt and released with DeallocateStmt.
// See: https://dev.mysql.com/doc/refman/5.7/en/prepare.html
type PrepareStmt struct {
stmtNode
InPrepare bool // true for prepare mode, false for use mode
Name string
ID uint32 // For binary protocol, there is no Name but only ID
SQLStmt Node // The parsed statement from sql text with placeholder
}
// Accept implements Node Accept interface.
func (ps *PrepareStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(ps) {
return ps, false
}
node, ok := ps.SQLStmt.Accept(v)
if !ok {
return ps, false
}
ps.SQLStmt = node
return v.Leave(ps)
}
// DeallocateStmt is a statement to release PreparedStmt.
// See: https://dev.mysql.com/doc/refman/5.7/en/deallocate-prepare.html
type DeallocateStmt struct {
stmtNode
Name string
ID uint32 // For binary protocol, there is no Name but only ID.
}
// Accept implements Node Accept interface.
func (ds *DeallocateStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(ds) {
return ds, false
}
return v.Leave(ds)
}
// ExecuteStmt is a statement to execute PreparedStmt.
// See: https://dev.mysql.com/doc/refman/5.7/en/execute.html
type ExecuteStmt struct {
stmtNode
Name string
ID uint32 // For binary protocol, there is no Name but only ID
UsingVars []ExprNode
}
// Accept implements Node Accept interface.
func (es *ExecuteStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(es) {
return es, false
}
for i, val := range es.UsingVars {
node, ok := val.Accept(v)
if !ok {
return es, false
}
es.UsingVars[i] = node.(ExprNode)
}
return v.Leave(es)
}
// ShowStmt is a statement to provide information about databases, tables, columns and so on.
// See: https://dev.mysql.com/doc/refman/5.7/en/show.html
type ShowStmt struct {
stmtNode
Target int // Databases/Tables/Columns/....
DBName string
Table *TableRef // Used for showing columns.
Column *ColumnRefExpr // Used for `desc table column`.
Flag int // Some flag parsed from sql, such as FULL.
Full bool
// Used by show variables
GlobalScope bool
Pattern *PatternLikeExpr
Where ExprNode
}
// Accept implements Node Accept interface.
func (ss *ShowStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(ss) {
return ss, false
}
if ss.Table != nil {
node, ok := ss.Table.Accept(v)
if !ok {
return ss, false
}
ss.Table = node.(*TableRef)
}
if ss.Column != nil {
node, ok := ss.Column.Accept(v)
if !ok {
return ss, false
}
ss.Column = node.(*ColumnRefExpr)
}
if ss.Pattern != nil {
node, ok := ss.Pattern.Accept(v)
if !ok {
return ss, false
}
ss.Pattern = node.(*PatternLikeExpr)
}
if ss.Where != nil {
node, ok := ss.Where.Accept(v)
if !ok {
return ss, false
}
ss.Where = node.(ExprNode)
}
return v.Leave(ss)
}
// BeginStmt is a statement to start a new transaction.
// See: https://dev.mysql.com/doc/refman/5.7/en/commit.html
type BeginStmt struct {
stmtNode
}
// Accept implements Node Accept interface.
func (bs *BeginStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(bs) {
return bs, false
}
return v.Leave(bs)
}
// CommitStmt is a statement to commit the current transaction.
// See: https://dev.mysql.com/doc/refman/5.7/en/commit.html
type CommitStmt struct {
stmtNode
}
// Accept implements Node Accept interface.
func (cs *CommitStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(cs) {
return cs, false
}
return v.Leave(cs)
}
// RollbackStmt is a statement to roll back the current transaction.
// See: https://dev.mysql.com/doc/refman/5.7/en/commit.html
type RollbackStmt struct {
stmtNode
}
// Accept implements Node Accept interface.
func (rs *RollbackStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(rs) {
return rs, false
}
return v.Leave(rs)
}
// UseStmt is a statement to use the DBName database as the current database.
// See: https://dev.mysql.com/doc/refman/5.7/en/use.html
type UseStmt struct {
stmtNode
DBName string
}
// Accept implements Node Accept interface.
func (us *UseStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(us) {
return us, false
}
return v.Leave(us)
}
// VariableAssignment is a variable assignment struct.
type VariableAssignment struct {
node
Name string
Value ExprNode
IsGlobal bool
IsSystem bool
}
// Accept implements Node interface.
func (va *VariableAssignment) Accept(v Visitor) (Node, bool) {
if !v.Enter(va) {
return va, false
}
node, ok := va.Value.Accept(v)
if !ok {
return va, false
}
va.Value = node.(ExprNode)
return v.Leave(va)
}
// SetStmt is the statement to set variables.
type SetStmt struct {
stmtNode
// Variables is the list of variable assignment.
Variables []*VariableAssignment
}
// Accept implements Node Accept interface.
func (set *SetStmt) Accept(v Visitor) (Node, bool) {
if !v.Enter(set) {
return set, false
}
for i, val := range set.Variables {
node, ok := val.Accept(v)
if !ok {
return set, false
}
set.Variables[i] = node.(*VariableAssignment)
}
return v.Leave(set)
}

View File

@ -18,8 +18,6 @@
package column
import (
"bytes"
"fmt"
"strings"
"github.com/juju/errors"
@ -46,7 +44,7 @@ type IndexedCol struct {
// String implements fmt.Stringer interface.
func (c *Col) String() string {
ans := []string{c.Name.O, types.FieldTypeToStr(c.Tp, c.Charset)}
ans := []string{c.Name.O, types.TypeToStr(c.Tp, c.Charset)}
if mysql.HasAutoIncrementFlag(c.Flag) {
ans = append(ans, "AUTO_INCREMENT")
}
@ -121,41 +119,11 @@ const defaultPrivileges string = "select,insert,update,references"
// GetTypeDesc gets the description for column type.
func (c *Col) GetTypeDesc() string {
var buf bytes.Buffer
buf.WriteString(types.FieldTypeToStr(c.Tp, c.Charset))
switch c.Tp {
case mysql.TypeSet, mysql.TypeEnum:
// Format is ENUM ('e1', 'e2') or SET ('e1', 'e2')
// If elem contain ', we will convert ' -> ''
elems := make([]string, len(c.Elems))
for i := range elems {
elems[i] = strings.Replace(c.Elems[i], "'", "''", -1)
}
buf.WriteString(fmt.Sprintf("('%s')", strings.Join(elems, "','")))
case mysql.TypeFloat, mysql.TypeDouble:
// if only float(M), we will use float. The same for double.
if c.Flen != -1 && c.Decimal != -1 {
buf.WriteString(fmt.Sprintf("(%d,%d)", c.Flen, c.Decimal))
}
case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate:
if c.Decimal != -1 && c.Decimal != 0 {
buf.WriteString(fmt.Sprintf("(%d)", c.Decimal))
}
default:
if c.Flen != -1 {
if c.Decimal == -1 {
buf.WriteString(fmt.Sprintf("(%d)", c.Flen))
} else {
buf.WriteString(fmt.Sprintf("(%d,%d)", c.Flen, c.Decimal))
}
}
}
desc := c.FieldType.CompactStr()
if mysql.HasUnsignedFlag(c.Flag) {
buf.WriteString(" UNSIGNED")
desc += " UNSIGNED"
}
return buf.String()
return desc
}
// NewColDesc returns a new ColDesc for a column.

View File

@ -198,5 +198,5 @@ func statement(sql string) stmt.Statement {
log.Debug("Compile", sql)
lexer := parser.NewLexer(sql)
parser.YYParse(lexer)
return lexer.Stmts()[0]
return lexer.Stmts()[0].(stmt.Statement)
}

View File

@ -104,11 +104,6 @@ func FastEval(v interface{}) interface{} {
}
}
// IsQualified returns whether name contains ".".
func IsQualified(name string) bool {
return strings.Contains(name, ".")
}
// Eval is a helper function evaluates expression v and do a panic if evaluating error.
func Eval(v Expression, ctx context.Context, env map[interface{}]interface{}) (y interface{}) {
var err error

View File

@ -43,16 +43,16 @@ func (*testFieldSuite) TestField(c *C) {
ft := types.NewFieldType(mysql.TypeLong)
ft.Flen = 20
ft.Flag |= mysql.UnsignedFlag | mysql.ZerofillFlag
c.Assert(ft.String(), Equals, "int (20) UNSIGNED ZEROFILL")
c.Assert(ft.String(), Equals, "int(20) UNSIGNED ZEROFILL")
ft = types.NewFieldType(mysql.TypeFloat)
ft.Flen = 20
ft.Decimal = 10
c.Assert(ft.String(), Equals, "float (20, 10)")
c.Assert(ft.String(), Equals, "float(20,10)")
ft = types.NewFieldType(mysql.TypeTimestamp)
ft.Decimal = 8
c.Assert(ft.String(), Equals, "timestamp (8)")
c.Assert(ft.String(), Equals, "timestamp(8)")
ft = types.NewFieldType(mysql.TypeVarchar)
ft.Flag |= mysql.BinaryFlag

View File

@ -24,7 +24,6 @@ import (
"github.com/juju/errors"
"github.com/pingcap/tidb/column"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/util/types"
)
const (
@ -89,14 +88,6 @@ func ColToResultField(col *column.Col, tableName string) *ResultField {
TableName: tableName,
OrgTableName: tableName,
}
if rf.Col.Flen == types.UnspecifiedLength {
rf.Col.Flen = 0
}
if rf.Col.Decimal == types.UnspecifiedLength {
rf.Col.Decimal = 0
}
// Keep things compatible for old clients.
// Refer to mysql-server/sql/protocol.cc send_result_set_metadata()
if rf.Tp == mysql.TypeVarchar {
@ -290,3 +281,8 @@ func CheckWildcardField(name string) (string, bool, error) {
_, table, field := SplitQualifiedName(name)
return table, field == "*", nil
}
// IsQualifiedName returns whether name contains "." or not.
func IsQualifiedName(name string) bool {
return strings.Contains(name, ".")
}

View File

@ -25,7 +25,7 @@ type Key []byte
// Next returns the next key in byte-order.
func (k Key) Next() Key {
// add \x0 to the end of key
// add 0x0 to the end of key
buf := make([]byte, len([]byte(k))+1)
copy(buf, []byte(k))
return buf

View File

@ -32,13 +32,13 @@ func (*testBtreeSuite) TestBtree(c *C) {
v := []interface{}{string(i + 1)}
t.Set(k, v)
}
for i := 0; i < 102400; i++ {
for i := 0; i < 1024; i++ {
k := []interface{}{i}
v := []interface{}{string(i)}
t.Set(k, v)
}
// Delete
for i := 512; i < 102400; i++ {
for i := 512; i < 1024; i++ {
k := []interface{}{i}
t.Delete(k)
}
@ -51,7 +51,7 @@ func (*testBtreeSuite) TestBtree(c *C) {
c.Assert(v[0], Equals, string(i))
}
// Get unexists key
for i := 512; i < 102400; i++ {
for i := 512; i < 1024; i++ {
k := []interface{}{i}
v, ok := t.Get(k)
c.Assert(ok, IsFalse)

View File

@ -117,7 +117,7 @@ func (iter *UnionIter) updateCur() {
// Next implements the Iterator Next interface.
func (iter *UnionIter) Next(f FnKeyCmp) (Iterator, error) {
if iter.curIsDirty == false {
if !iter.curIsDirty {
iter.snapshotNext()
} else {
iter.dirtyNext()
@ -129,7 +129,7 @@ func (iter *UnionIter) Next(f FnKeyCmp) (Iterator, error) {
// Value implements the Iterator Value interface.
// Multi columns
func (iter *UnionIter) Value() []byte {
if iter.curIsDirty == false {
if !iter.curIsDirty {
return iter.snapshotIt.Value()
}
return iter.dirtyIt.Value()
@ -137,7 +137,7 @@ func (iter *UnionIter) Value() []byte {
// Key implements the Iterator Key interface.
func (iter *UnionIter) Key() string {
if iter.curIsDirty == false {
if !iter.curIsDirty {
return string(DecodeKey([]byte(iter.snapshotIt.Key())))
}
return string(DecodeKey(iter.dirtyIt.Key()))

View File

@ -15,6 +15,7 @@ package meta
import (
"fmt"
"sync"
"github.com/juju/errors"
"github.com/ngaut/log"
@ -35,6 +36,10 @@ var (
globalIDKey = MakeMetaKey("mNextGlobalID")
)
var (
globalIDMutex sync.Mutex
)
// MakeMetaKey creates meta key
func MakeMetaKey(key string) []byte {
return append([]byte{0x0}, key...)
@ -67,11 +72,13 @@ func AutoIDKey(tableID int64) string {
if tableID == 0 {
log.Error("Invalid tableID")
}
return fmt.Sprintf("%s:%d_autoID", TableMetaPrefix, tableID)
return fmt.Sprintf("%s:%d_auto_id", TableMetaPrefix, tableID)
}
// GenGlobalID generates the next id in the store scope.
func GenGlobalID(store kv.Storage) (ID int64, err error) {
globalIDMutex.Lock()
defer globalIDMutex.Unlock()
err = kv.RunInNewTxn(store, true, func(txn kv.Transaction) error {
ID, err = GenID(txn, globalIDKey, 1)
if err != nil {

View File

@ -54,9 +54,9 @@ func (*testSuite) TestT(c *C) {
//For AutoIDKey
mkey = []byte(meta.AutoIDKey(1))
c.Assert(mkey, DeepEquals, meta.MakeMetaKey("mTable::1_autoID"))
c.Assert(mkey, DeepEquals, meta.MakeMetaKey("mTable::1_auto_id"))
mkey = []byte(meta.AutoIDKey(0))
c.Assert(mkey, DeepEquals, meta.MakeMetaKey("mTable::0_autoID"))
c.Assert(mkey, DeepEquals, meta.MakeMetaKey("mTable::0_auto_id"))
// For GenGlobalID
id, err = meta.GenGlobalID(store)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -36,9 +36,8 @@ func (e *SQLError) Error() string {
return fmt.Sprintf("ERROR %d (%s): %s", e.Code, e.State, e.Message)
}
// NewDefaultError generates a SQL error, with an error code and
// extra arguments for a message format specifier.
func NewDefaultError(errCode uint16, args ...interface{}) *SQLError {
// NewErr generates a SQL error, with an error code and default format specifier defined in MySQLErrName.
func NewErr(errCode uint16, args ...interface{}) *SQLError {
e := &SQLError{Code: errCode}
if s, ok := MySQLState[errCode]; ok {
@ -56,8 +55,8 @@ func NewDefaultError(errCode uint16, args ...interface{}) *SQLError {
return e
}
// NewError creates a SQL error, with an error code and error details.
func NewError(errCode uint16, message string) *SQLError {
// NewErrf creates a SQL error, with an error code and a format specifier
func NewErrf(errCode uint16, format string, args ...interface{}) *SQLError {
e := &SQLError{Code: errCode}
if s, ok := MySQLState[errCode]; ok {
@ -66,7 +65,7 @@ func NewError(errCode uint16, message string) *SQLError {
e.State = DefaultMySQLState
}
e.Message = message
e.Message = fmt.Sprintf(format, args...)
return e
}

View File

@ -23,15 +23,15 @@ type testSQLErrorSuite struct {
}
func (s *testSQLErrorSuite) TestSQLError(c *C) {
e := NewError(ErNoDbError, "no db error")
e := NewErrf(ErrNoDb, "no db error")
c.Assert(len(e.Error()), Greater, 0)
e = NewError(0, "customized error")
e = NewErrf(0, "customized error")
c.Assert(len(e.Error()), Greater, 0)
e = NewDefaultError(ErNoDbError)
e = NewErr(ErrNoDb)
c.Assert(len(e.Error()), Greater, 0)
e = NewDefaultError(0, "customized error")
e = NewErr(0, "customized error")
c.Assert(len(e.Error()), Greater, 0)
}

View File

@ -21,229 +21,229 @@ const (
// MySQLState maps error code to MySQL SQLSTATE value.
// The values are taken from ANSI SQL and ODBC and are more standardized.
var MySQLState = map[uint16]string{
ErDupKey: "23000",
ErOutofmemory: "HY001",
ErOutOfSortmemory: "HY001",
ErConCountError: "08004",
ErBadHostError: "08S01",
ErHandshakeError: "08S01",
ErDbaccessDeniedError: "42000",
ErAccessDeniedError: "28000",
ErNoDbError: "3D000",
ErUnknownComError: "08S01",
ErBadNullError: "23000",
ErBadDbError: "42000",
ErTableExistsError: "42S01",
ErBadTableError: "42S02",
ErNonUniqError: "23000",
ErServerShutdown: "08S01",
ErBadFieldError: "42S22",
ErWrongFieldWithGroup: "42000",
ErWrongSumSelect: "42000",
ErWrongGroupField: "42000",
ErWrongValueCount: "21S01",
ErTooLongIdent: "42000",
ErDupFieldname: "42S21",
ErDupKeyname: "42000",
ErDupEntry: "23000",
ErWrongFieldSpec: "42000",
ErParseError: "42000",
ErEmptyQuery: "42000",
ErNonuniqTable: "42000",
ErInvalidDefault: "42000",
ErMultiplePriKey: "42000",
ErTooManyKeys: "42000",
ErTooManyKeyParts: "42000",
ErTooLongKey: "42000",
ErKeyColumnDoesNotExits: "42000",
ErBlobUsedAsKey: "42000",
ErTooBigFieldlength: "42000",
ErWrongAutoKey: "42000",
ErForcingClose: "08S01",
ErIpsockError: "08S01",
ErNoSuchIndex: "42S12",
ErWrongFieldTerminators: "42000",
ErBlobsAndNoTerminated: "42000",
ErCantRemoveAllFields: "42000",
ErCantDropFieldOrKey: "42000",
ErBlobCantHaveDefault: "42000",
ErWrongDbName: "42000",
ErWrongTableName: "42000",
ErTooBigSelect: "42000",
ErUnknownProcedure: "42000",
ErWrongParamcountToProcedure: "42000",
ErUnknownTable: "42S02",
ErFieldSpecifiedTwice: "42000",
ErUnsupportedExtension: "42000",
ErTableMustHaveColumns: "42000",
ErUnknownCharacterSet: "42000",
ErTooBigRowsize: "42000",
ErWrongOuterJoin: "42000",
ErNullColumnInIndex: "42000",
ErPasswordAnonymousUser: "42000",
ErPasswordNotAllowed: "42000",
ErPasswordNoMatch: "42000",
ErWrongValueCountOnRow: "21S01",
ErInvalidUseOfNull: "22004",
ErRegexpError: "42000",
ErMixOfGroupFuncAndFields: "42000",
ErNonexistingGrant: "42000",
ErTableaccessDeniedError: "42000",
ErColumnaccessDeniedError: "42000",
ErIllegalGrantForTable: "42000",
ErGrantWrongHostOrUser: "42000",
ErNoSuchTable: "42S02",
ErNonexistingTableGrant: "42000",
ErNotAllowedCommand: "42000",
ErSyntaxError: "42000",
ErAbortingConnection: "08S01",
ErNetPacketTooLarge: "08S01",
ErNetReadErrorFromPipe: "08S01",
ErNetFcntlError: "08S01",
ErNetPacketsOutOfOrder: "08S01",
ErNetUncompressError: "08S01",
ErNetReadError: "08S01",
ErNetReadInterrupted: "08S01",
ErNetErrorOnWrite: "08S01",
ErNetWriteInterrupted: "08S01",
ErTooLongString: "42000",
ErTableCantHandleBlob: "42000",
ErTableCantHandleAutoIncrement: "42000",
ErWrongColumnName: "42000",
ErWrongKeyColumn: "42000",
ErDupUnique: "23000",
ErBlobKeyWithoutLength: "42000",
ErPrimaryCantHaveNull: "42000",
ErTooManyRows: "42000",
ErRequiresPrimaryKey: "42000",
ErKeyDoesNotExits: "42000",
ErCheckNoSuchTable: "42000",
ErCheckNotImplemented: "42000",
ErCantDoThisDuringAnTransaction: "25000",
ErNewAbortingConnection: "08S01",
ErMasterNetRead: "08S01",
ErMasterNetWrite: "08S01",
ErTooManyUserConnections: "42000",
ErReadOnlyTransaction: "25000",
ErNoPermissionToCreateUser: "42000",
ErLockDeadlock: "40001",
ErNoReferencedRow: "23000",
ErRowIsReferenced: "23000",
ErConnectToMaster: "08S01",
ErWrongNumberOfColumnsInSelect: "21000",
ErUserLimitReached: "42000",
ErSpecificAccessDeniedError: "42000",
ErNoDefault: "42000",
ErWrongValueForVar: "42000",
ErWrongTypeForVar: "42000",
ErCantUseOptionHere: "42000",
ErNotSupportedYet: "42000",
ErWrongFkDef: "42000",
ErOperandColumns: "21000",
ErSubqueryNo1Row: "21000",
ErIllegalReference: "42S22",
ErDerivedMustHaveAlias: "42000",
ErSelectReduced: "01000",
ErTablenameNotAllowedHere: "42000",
ErNotSupportedAuthMode: "08004",
ErSpatialCantHaveNull: "42000",
ErCollationCharsetMismatch: "42000",
ErWarnTooFewRecords: "01000",
ErWarnTooManyRecords: "01000",
ErWarnNullToNotnull: "22004",
ErWarnDataOutOfRange: "22003",
WarnDataTruncated: "01000",
ErWrongNameForIndex: "42000",
ErWrongNameForCatalog: "42000",
ErUnknownStorageEngine: "42000",
ErTruncatedWrongValue: "22007",
ErSpNoRecursiveCreate: "2F003",
ErSpAlreadyExists: "42000",
ErSpDoesNotExist: "42000",
ErSpLilabelMismatch: "42000",
ErSpLabelRedefine: "42000",
ErSpLabelMismatch: "42000",
ErSpUninitVar: "01000",
ErSpBadselect: "0A000",
ErSpBadreturn: "42000",
ErSpBadstatement: "0A000",
ErUpdateLogDeprecatedIgnored: "42000",
ErUpdateLogDeprecatedTranslated: "42000",
ErQueryInterrupted: "70100",
ErSpWrongNoOfArgs: "42000",
ErSpCondMismatch: "42000",
ErSpNoreturn: "42000",
ErSpNoreturnend: "2F005",
ErSpBadCursorQuery: "42000",
ErSpBadCursorSelect: "42000",
ErSpCursorMismatch: "42000",
ErSpCursorAlreadyOpen: "24000",
ErSpCursorNotOpen: "24000",
ErSpUndeclaredVar: "42000",
ErSpFetchNoData: "02000",
ErSpDupParam: "42000",
ErSpDupVar: "42000",
ErSpDupCond: "42000",
ErSpDupCurs: "42000",
ErSpSubselectNyi: "0A000",
ErStmtNotAllowedInSfOrTrg: "0A000",
ErSpVarcondAfterCurshndlr: "42000",
ErSpCursorAfterHandler: "42000",
ErSpCaseNotFound: "20000",
ErDivisionByZero: "22012",
ErIllegalValueForType: "22007",
ErProcaccessDeniedError: "42000",
ErXaerNota: "XAE04",
ErXaerInval: "XAE05",
ErXaerRmfail: "XAE07",
ErXaerOutside: "XAE09",
ErXaerRmerr: "XAE03",
ErXaRbrollback: "XA100",
ErNonexistingProcGrant: "42000",
ErDataTooLong: "22001",
ErSpBadSQLstate: "42000",
ErCantCreateUserWithGrant: "42000",
ErSpDupHandler: "42000",
ErSpNotVarArg: "42000",
ErSpNoRetset: "0A000",
ErCantCreateGeometryObject: "22003",
ErTooBigScale: "42000",
ErTooBigPrecision: "42000",
ErMBiggerThanD: "42000",
ErTooLongBody: "42000",
ErTooBigDisplaywidth: "42000",
ErXaerDupid: "XAE08",
ErDatetimeFunctionOverflow: "22008",
ErRowIsReferenced2: "23000",
ErNoReferencedRow2: "23000",
ErSpBadVarShadow: "42000",
ErSpWrongName: "42000",
ErSpNoAggregate: "42000",
ErMaxPreparedStmtCountReached: "42000",
ErNonGroupingFieldUsed: "42000",
ErForeignDuplicateKeyOldUnused: "23000",
ErCantChangeTxCharacteristics: "25001",
ErWrongParamcountToNativeFct: "42000",
ErWrongParametersToNativeFct: "42000",
ErWrongParametersToStoredFct: "42000",
ErDupEntryWithKeyName: "23000",
ErXaRbtimeout: "XA106",
ErXaRbdeadlock: "XA102",
ErFuncInexistentNameCollision: "42000",
ErDupSignalSet: "42000",
ErSignalWarn: "01000",
ErSignalNotFound: "02000",
ErSignalException: "HY000",
ErResignalWithoutActiveHandler: "0K000",
ErSpatialMustHaveGeomCol: "42000",
ErDataOutOfRange: "22003",
ErAccessDeniedNoPasswordError: "28000",
ErTruncateIllegalFk: "42000",
ErDaInvalidConditionNumber: "35000",
ErForeignDuplicateKeyWithChildInfo: "23000",
ErForeignDuplicateKeyWithoutChildInfo: "23000",
ErCantExecuteInReadOnlyTransaction: "25006",
ErAlterOperationNotSupported: "0A000",
ErAlterOperationNotSupportedReason: "0A000",
ErDupUnknownInIndex: "23000",
ErrDupKey: "23000",
ErrOutofmemory: "HY001",
ErrOutOfSortmemory: "HY001",
ErrConCount: "08004",
ErrBadHost: "08S01",
ErrHandshake: "08S01",
ErrDbaccessDenied: "42000",
ErrAccessDenied: "28000",
ErrNoDb: "3D000",
ErrUnknownCom: "08S01",
ErrBadNull: "23000",
ErrBadDb: "42000",
ErrTableExists: "42S01",
ErrBadTable: "42S02",
ErrNonUniq: "23000",
ErrServerShutdown: "08S01",
ErrBadField: "42S22",
ErrWrongFieldWithGroup: "42000",
ErrWrongSumSelect: "42000",
ErrWrongGroupField: "42000",
ErrWrongValueCount: "21S01",
ErrTooLongIdent: "42000",
ErrDupFieldname: "42S21",
ErrDupKeyname: "42000",
ErrDupEntry: "23000",
ErrWrongFieldSpec: "42000",
ErrParse: "42000",
ErrEmptyQuery: "42000",
ErrNonuniqTable: "42000",
ErrInvalidDefault: "42000",
ErrMultiplePriKey: "42000",
ErrTooManyKeys: "42000",
ErrTooManyKeyParts: "42000",
ErrTooLongKey: "42000",
ErrKeyColumnDoesNotExits: "42000",
ErrBlobUsedAsKey: "42000",
ErrTooBigFieldlength: "42000",
ErrWrongAutoKey: "42000",
ErrForcingClose: "08S01",
ErrIpsock: "08S01",
ErrNoSuchIndex: "42S12",
ErrWrongFieldTerminators: "42000",
ErrBlobsAndNoTerminated: "42000",
ErrCantRemoveAllFields: "42000",
ErrCantDropFieldOrKey: "42000",
ErrBlobCantHaveDefault: "42000",
ErrWrongDbName: "42000",
ErrWrongTableName: "42000",
ErrTooBigSelect: "42000",
ErrUnknownProcedure: "42000",
ErrWrongParamcountToProcedure: "42000",
ErrUnknownTable: "42S02",
ErrFieldSpecifiedTwice: "42000",
ErrUnsupportedExtension: "42000",
ErrTableMustHaveColumns: "42000",
ErrUnknownCharacterSet: "42000",
ErrTooBigRowsize: "42000",
ErrWrongOuterJoin: "42000",
ErrNullColumnInIndex: "42000",
ErrPasswordAnonymousUser: "42000",
ErrPasswordNotAllowed: "42000",
ErrPasswordNoMatch: "42000",
ErrWrongValueCountOnRow: "21S01",
ErrInvalidUseOfNull: "22004",
ErrRegexp: "42000",
ErrMixOfGroupFuncAndFields: "42000",
ErrNonexistingGrant: "42000",
ErrTableaccessDenied: "42000",
ErrColumnaccessDenied: "42000",
ErrIllegalGrantForTable: "42000",
ErrGrantWrongHostOrUser: "42000",
ErrNoSuchTable: "42S02",
ErrNonexistingTableGrant: "42000",
ErrNotAllowedCommand: "42000",
ErrSyntax: "42000",
ErrAbortingConnection: "08S01",
ErrNetPacketTooLarge: "08S01",
ErrNetReadErrorFromPipe: "08S01",
ErrNetFcntl: "08S01",
ErrNetPacketsOutOfOrder: "08S01",
ErrNetUncompress: "08S01",
ErrNetRead: "08S01",
ErrNetReadInterrupted: "08S01",
ErrNetErrorOnWrite: "08S01",
ErrNetWriteInterrupted: "08S01",
ErrTooLongString: "42000",
ErrTableCantHandleBlob: "42000",
ErrTableCantHandleAutoIncrement: "42000",
ErrWrongColumnName: "42000",
ErrWrongKeyColumn: "42000",
ErrDupUnique: "23000",
ErrBlobKeyWithoutLength: "42000",
ErrPrimaryCantHaveNull: "42000",
ErrTooManyRows: "42000",
ErrRequiresPrimaryKey: "42000",
ErrKeyDoesNotExits: "42000",
ErrCheckNoSuchTable: "42000",
ErrCheckNotImplemented: "42000",
ErrCantDoThisDuringAnTransaction: "25000",
ErrNewAbortingConnection: "08S01",
ErrMasterNetRead: "08S01",
ErrMasterNetWrite: "08S01",
ErrTooManyUserConnections: "42000",
ErrReadOnlyTransaction: "25000",
ErrNoPermissionToCreateUser: "42000",
ErrLockDeadlock: "40001",
ErrNoReferencedRow: "23000",
ErrRowIsReferenced: "23000",
ErrConnectToMaster: "08S01",
ErrWrongNumberOfColumnsInSelect: "21000",
ErrUserLimitReached: "42000",
ErrSpecificAccessDenied: "42000",
ErrNoDefault: "42000",
ErrWrongValueForVar: "42000",
ErrWrongTypeForVar: "42000",
ErrCantUseOptionHere: "42000",
ErrNotSupportedYet: "42000",
ErrWrongFkDef: "42000",
ErrOperandColumns: "21000",
ErrSubqueryNo1Row: "21000",
ErrIllegalReference: "42S22",
ErrDerivedMustHaveAlias: "42000",
ErrSelectReduced: "01000",
ErrTablenameNotAllowedHere: "42000",
ErrNotSupportedAuthMode: "08004",
ErrSpatialCantHaveNull: "42000",
ErrCollationCharsetMismatch: "42000",
ErrWarnTooFewRecords: "01000",
ErrWarnTooManyRecords: "01000",
ErrWarnNullToNotnull: "22004",
ErrWarnDataOutOfRange: "22003",
WarnDataTruncated: "01000",
ErrWrongNameForIndex: "42000",
ErrWrongNameForCatalog: "42000",
ErrUnknownStorageEngine: "42000",
ErrTruncatedWrongValue: "22007",
ErrSpNoRecursiveCreate: "2F003",
ErrSpAlreadyExists: "42000",
ErrSpDoesNotExist: "42000",
ErrSpLilabelMismatch: "42000",
ErrSpLabelRedefine: "42000",
ErrSpLabelMismatch: "42000",
ErrSpUninitVar: "01000",
ErrSpBadselect: "0A000",
ErrSpBadreturn: "42000",
ErrSpBadstatement: "0A000",
ErrUpdateLogDeprecatedIgnored: "42000",
ErrUpdateLogDeprecatedTranslated: "42000",
ErrQueryInterrupted: "70100",
ErrSpWrongNoOfArgs: "42000",
ErrSpCondMismatch: "42000",
ErrSpNoreturn: "42000",
ErrSpNoreturnend: "2F005",
ErrSpBadCursorQuery: "42000",
ErrSpBadCursorSelect: "42000",
ErrSpCursorMismatch: "42000",
ErrSpCursorAlreadyOpen: "24000",
ErrSpCursorNotOpen: "24000",
ErrSpUndeclaredVar: "42000",
ErrSpFetchNoData: "02000",
ErrSpDupParam: "42000",
ErrSpDupVar: "42000",
ErrSpDupCond: "42000",
ErrSpDupCurs: "42000",
ErrSpSubselectNyi: "0A000",
ErrStmtNotAllowedInSfOrTrg: "0A000",
ErrSpVarcondAfterCurshndlr: "42000",
ErrSpCursorAfterHandler: "42000",
ErrSpCaseNotFound: "20000",
ErrDivisionByZero: "22012",
ErrIllegalValueForType: "22007",
ErrProcaccessDenied: "42000",
ErrXaerNota: "XAE04",
ErrXaerInval: "XAE05",
ErrXaerRmfail: "XAE07",
ErrXaerOutside: "XAE09",
ErrXaerRmerr: "XAE03",
ErrXaRbrollback: "XA100",
ErrNonexistingProcGrant: "42000",
ErrDataTooLong: "22001",
ErrSpBadSQLstate: "42000",
ErrCantCreateUserWithGrant: "42000",
ErrSpDupHandler: "42000",
ErrSpNotVarArg: "42000",
ErrSpNoRetset: "0A000",
ErrCantCreateGeometryObject: "22003",
ErrTooBigScale: "42000",
ErrTooBigPrecision: "42000",
ErrMBiggerThanD: "42000",
ErrTooLongBody: "42000",
ErrTooBigDisplaywidth: "42000",
ErrXaerDupid: "XAE08",
ErrDatetimeFunctionOverflow: "22008",
ErrRowIsReferenced2: "23000",
ErrNoReferencedRow2: "23000",
ErrSpBadVarShadow: "42000",
ErrSpWrongName: "42000",
ErrSpNoAggregate: "42000",
ErrMaxPreparedStmtCountReached: "42000",
ErrNonGroupingFieldUsed: "42000",
ErrForeignDuplicateKeyOldUnused: "23000",
ErrCantChangeTxCharacteristics: "25001",
ErrWrongParamcountToNativeFct: "42000",
ErrWrongParametersToNativeFct: "42000",
ErrWrongParametersToStoredFct: "42000",
ErrDupEntryWithKeyName: "23000",
ErrXaRbtimeout: "XA106",
ErrXaRbdeadlock: "XA102",
ErrFuncInexistentNameCollision: "42000",
ErrDupSignalSet: "42000",
ErrSignalWarn: "01000",
ErrSignalNotFound: "02000",
ErrSignalException: "HY000",
ErrResignalWithoutActiveHandler: "0K000",
ErrSpatialMustHaveGeomCol: "42000",
ErrDataOutOfRange: "22003",
ErrAccessDeniedNoPassword: "28000",
ErrTruncateIllegalFk: "42000",
ErrDaInvalidConditionNumber: "35000",
ErrForeignDuplicateKeyWithChildInfo: "23000",
ErrForeignDuplicateKeyWithoutChildInfo: "23000",
ErrCantExecuteInReadOnlyTransaction: "25006",
ErrAlterOperationNotSupported: "0A000",
ErrAlterOperationNotSupportedReason: "0A000",
ErrDupUnknownInIndex: "23000",
}

View File

@ -17,6 +17,7 @@ package mysqldef
// Call this when no Flen assigned in ddl.
// or column value is calculated from an expression.
// For example: "select count(*) from t;", the column type is int64 and Flen in ResultField will be 21.
// See: https://dev.mysql.com/doc/refman/5.7/en/storage-requirements.html
func GetDefaultFieldLength(tp byte) int {
switch tp {
case TypeTiny:
@ -29,8 +30,25 @@ func GetDefaultFieldLength(tp byte) int {
return 11
case TypeLonglong:
return 21
case TypeDecimal:
// See: https://dev.mysql.com/doc/refman/5.7/en/fixed-point-types.html
return 10
case TypeBit, TypeBlob:
return -1
default:
//TODO: add more types
return 0
return -1
}
}
// GetDefaultDecimal returns the default decimal length for column.
func GetDefaultDecimal(tp byte) int {
switch tp {
case TypeDecimal:
// See: https://dev.mysql.com/doc/refman/5.7/en/fixed-point-types.html
return 0
default:
//TODO: add more types
return -1
}
}

View File

@ -25,7 +25,9 @@ func TestGetFieldLength(t *testing.T) {
{TypeInt24, 9},
{TypeLong, 11},
{TypeLonglong, 21},
{TypeNull, 0},
{TypeBit, -1},
{TypeBlob, -1},
{TypeNull, -1},
}
for _, test := range tbl {

32
optimizer/optimizer.go Normal file
View File

@ -0,0 +1,32 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package optimizer
import (
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/stmt"
)
// Compile compiles a ast.Node into a executable statement.
func Compile(node ast.Node) (stmt.Statement, error) {
switch v := node.(type) {
case *ast.SetStmt:
return compileSet(v)
}
return nil, nil
}
func compileSet(aset *ast.SetStmt) (stmt.Statement, error) {
return nil, nil
}

View File

@ -181,9 +181,12 @@ func ColumnDefToCol(offset int, colDef *ColumnDef) (*column.Col, []*TableConstra
}
// If flen is not assigned, assigned it by type.
if col.Flen == 0 {
if col.Flen == types.UnspecifiedLength {
col.Flen = mysql.GetDefaultFieldLength(col.Tp)
}
if col.Decimal == types.UnspecifiedLength {
col.Decimal = mysql.GetDefaultDecimal(col.Tp)
}
setOnUpdateNow := false
hasDefaultValue := false

View File

@ -30,6 +30,7 @@ import (
"strings"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/parser/coldef"
"github.com/pingcap/tidb/ddl"
"github.com/pingcap/tidb/expression"
@ -2931,15 +2932,29 @@ RollbackStmt:
}
SelectStmt:
"SELECT" SelectStmtOpts SelectStmtFieldList FromDual SelectStmtLimit SelectLockOpt
"SELECT" SelectStmtOpts SelectStmtFieldList SelectStmtLimit SelectLockOpt
{
$$ = &stmts.SelectStmt {
Distinct: $2.(bool),
Fields: $3.([]*field.Field),
From: nil,
Lock: $6.(coldef.LockType),
Lock: $5.(coldef.LockType),
}
}
| "SELECT" SelectStmtOpts SelectStmtFieldList FromDual WhereClauseOptional SelectStmtLimit SelectLockOpt
{
st := &stmts.SelectStmt {
Distinct: $2.(bool),
Fields: $3.([]*field.Field),
From: nil,
Lock: $7.(coldef.LockType),
}
if $5 != nil {
st.Where = &rsets.WhereRset{Expr: $5.(expression.Expression)}
}
$$ = st
}
| "SELECT" SelectStmtOpts SelectStmtFieldList "FROM"
FromClause WhereClauseOptional SelectStmtGroup HavingClause SelectStmtOrder
SelectStmtLimit SelectLockOpt
@ -2977,8 +2992,7 @@ SelectStmt:
}
FromDual:
/* Empty */
| "FROM" "DUAL"
"FROM" "DUAL"
FromClause:
@ -3523,9 +3537,15 @@ StatementList:
Statement
{
if $1 != nil {
s := $1.(stmt.Statement)
s.SetText(yylex.(*lexer).stmtText())
yylex.(*lexer).list = []stmt.Statement{ s }
n, ok := $1.(ast.Node)
if ok {
n.SetText(yylex.(*lexer).stmtText())
yylex.(*lexer).list = []interface{}{n}
} else {
s := $1.(stmt.Statement)
s.SetText(yylex.(*lexer).stmtText())
yylex.(*lexer).list = []interface{}{s}
}
}
}
| StatementList ';' Statement
@ -3861,21 +3881,25 @@ BitValueType:
}
StringType:
NationalOpt "CHAR" FieldLen OptBinary
NationalOpt "CHAR" FieldLen OptBinary OptCharset OptCollate
{
x := types.NewFieldType(mysql.TypeString)
x.Flen = $3.(int)
if $4.(bool) {
x.Flag |= mysql.BinaryFlag
}
x.Charset = $5.(string)
x.Collate = $6.(string)
$$ = x
}
| NationalOpt "CHAR" OptBinary
| NationalOpt "CHAR" OptBinary OptCharset OptCollate
{
x := types.NewFieldType(mysql.TypeString)
if $3.(bool) {
x.Flag |= mysql.BinaryFlag
}
x.Charset = $4.(string)
x.Collate = $5.(string)
$$ = x
}
| NationalOpt "VARCHAR" FieldLen OptBinary OptCharset OptCollate

View File

@ -278,6 +278,8 @@ func (s *testParserSuite) TestDMLStmt(c *C) {
// For dual
{"select 1 from dual", true},
{"select 1 from dual limit 1", true},
{"select 1 where exists (select 2)", false},
{"select 1 from dual where not exists (select 2)", true},
// For show create table
{"show create table test.t", true},
@ -442,6 +444,11 @@ func (s *testParserSuite) TestDDL(c *C) {
{"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED) // foo", true},
{"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED) /* foo */", true},
{"CREATE TABLE foo /* foo */ (a SMALLINT UNSIGNED, b INT UNSIGNED) /* foo */", true},
{"CREATE TABLE foo (name CHAR(50) BINARY)", true},
{"CREATE TABLE foo (name CHAR(50) COLLATE utf8_bin)", true},
{"CREATE TABLE foo (name CHAR(50) CHARACTER SET utf8)", true},
{"CREATE TABLE foo (name CHAR(50) BINARY CHARACTER SET utf8 COLLATE utf8_bin)", true},
{"CREATE TABLE foo (a.b, b);", false},
{"CREATE TABLE foo (a, b.c);", false},
// For table option

View File

@ -26,7 +26,6 @@ import (
"strings"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/stmt"
mysql "github.com/pingcap/tidb/mysqldef"
)
@ -39,7 +38,7 @@ type lexer struct {
inj int
lcol int
line int
list []stmt.Statement
list []interface{}
ncol int
nline int
sc int
@ -73,7 +72,7 @@ func (l *lexer) Errors() []error {
return l.errs
}
func (l *lexer) Stmts() []stmt.Statement{
func (l *lexer) Stmts() []interface{}{
return l.list
}

View File

@ -32,6 +32,7 @@ type HavingPlan struct {
Src plan.Plan
Expr expression.Expression
evalArgs map[interface{}]interface{}
*SelectList
}
// Explain implements plan.Plan Explain interface.

View File

@ -30,6 +30,10 @@ var _ = Suite(&testHavingPlan{})
func (t *testHavingPlan) TestHaving(c *C) {
tblPlan := &testTablePlan{groupByTestData, []string{"id", "name"}, 0}
havingPlan := &plans.HavingPlan{
SelectList: &plans.SelectList{
HiddenFieldOffset: len(tblPlan.GetFields()),
ResultFields: tblPlan.GetFields(),
},
Src: tblPlan,
Expr: &expression.BinaryOperation{
Op: opcode.GE,

View File

@ -194,6 +194,7 @@ func (r *indexPlan) GetFields() []*field.ResultField {
// Filter implements plan.Plan Filter interface.
// Filter merges BinaryOperations and determines the lower and upper bound.
func (r *indexPlan) Filter(ctx context.Context, expr expression.Expression) (plan.Plan, bool, error) {
var spans []*indexSpan
switch x := expr.(type) {
case *expression.BinaryOperation:
ok, name, val, err := x.IsIdentCompareVal()
@ -214,14 +215,12 @@ func (r *indexPlan) Filter(ctx context.Context, expr expression.Expression) (pla
if err != nil {
return nil, false, errors.Trace(err)
}
r.spans = filterSpans(r.spans, toSpans(x.Op, val, seekVal))
return r, true, nil
spans = filterSpans(r.spans, toSpans(x.Op, val, seekVal))
case *expression.Ident:
if r.col.Name.L != x.L {
break
}
r.spans = filterSpans(r.spans, toSpans(opcode.GE, minNotNullVal, nil))
return r, true, nil
spans = filterSpans(r.spans, toSpans(opcode.GE, minNotNullVal, nil))
case *expression.UnaryOperation:
if x.Op != '!' {
break
@ -234,16 +233,25 @@ func (r *indexPlan) Filter(ctx context.Context, expr expression.Expression) (pla
if r.col.Name.L != cname.L {
break
}
r.spans = filterSpans(r.spans, toSpans(opcode.EQ, nil, nil))
return r, true, nil
spans = filterSpans(r.spans, toSpans(opcode.EQ, nil, nil))
}
return r, false, nil
if spans == nil {
return r, false, nil
}
return &indexPlan{
src: r.src,
col: r.col,
idxName: r.idxName,
idx: r.idx,
spans: spans,
}, true, nil
}
// return the intersection range between origin and filter.
func filterSpans(origin []*indexSpan, filter []*indexSpan) []*indexSpan {
var newSpans []*indexSpan
newSpans := make([]*indexSpan, 0, len(filter))
for _, fSpan := range filter {
for _, oSpan := range origin {
newSpan := oSpan.cutOffLow(fSpan.lowVal, fSpan.lowExclude)

View File

@ -361,22 +361,21 @@ func (isp *InfoSchemaPlan) fetchColumns(schemas []*model.DBInfo) {
if decimal == types.UnspecifiedLength {
decimal = 0
}
dataType := types.TypeToStr(col.Tp, col.Charset == charset.CharsetBin)
columnType := fmt.Sprintf("%s(%d)", dataType, colLen)
columnType := col.FieldType.CompactStr()
columnDesc := column.NewColDesc(&column.Col{ColumnInfo: *col})
var columnDefault interface{}
if columnDesc.DefaultValue != nil {
columnDefault = fmt.Sprintf("%v", columnDesc.DefaultValue)
}
record := []interface{}{
catalogVal, // TABLE_CATALOG
schema.Name.O, // TABLE_SCHEMA
table.Name.O, // TABLE_NAME
col.Name.O, // COLUMN_NAME
i + 1, // ORIGINAL_POSITION
columnDefault, // COLUMN_DEFAULT
columnDesc.Null, // IS_NULLABLE
types.TypeToStr(col.Tp, col.Charset == charset.CharsetBin), // DATA_TYPE
catalogVal, // TABLE_CATALOG
schema.Name.O, // TABLE_SCHEMA
table.Name.O, // TABLE_NAME
col.Name.O, // COLUMN_NAME
i + 1, // ORIGINAL_POSITION
columnDefault, // COLUMN_DEFAULT
columnDesc.Null, // IS_NULLABLE
types.TypeToStr(col.Tp, col.Charset), // DATA_TYPE
colLen, // CHARACTER_MAXIMUM_LENGTH
colLen, // CHARACTOR_OCTET_LENGTH
decimal, // NUMERIC_PRECISION

View File

@ -62,6 +62,9 @@ func mustQuery(c *C, currDB *sql.DB, s string) int {
c.Assert(err, IsNil)
cnt++
}
c.Assert(r.Err(), IsNil)
r.Close()
mustCommit(c, tx)
return cnt
}

View File

@ -289,6 +289,7 @@ func (r *JoinPlan) nextCrossJoin(ctx context.Context) (row *plan.Row, err error)
if r.curRow == nil {
return nil, nil
}
if r.On != nil {
tempExpr := r.On.Clone()
visitor := NewIdentEvalVisitor(r.Left.GetFields(), r.curRow.Data)
@ -316,7 +317,11 @@ func (r *JoinPlan) nextCrossJoin(ctx context.Context) (row *plan.Row, err error)
r.tempPlan.Close()
continue
}
joinedRow := append(r.curRow.Data, rightRow.Data...)
// To prevent outer modify the slice. See comment above.
joinedRow := make([]interface{}, 0, len(r.curRow.Data)+len(rightRow.Data))
joinedRow = append(append(joinedRow, r.curRow.Data...), rightRow.Data...)
if r.On != nil {
r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
return GetIdentValue(name, r.Fields, joinedRow, field.DefaultFieldFlag)

View File

@ -59,10 +59,13 @@ func (s *SelectList) updateFields(table string, resultFields []*field.ResultFiel
func createEmptyResultField(f *field.Field) *field.ResultField {
result := &field.ResultField{}
// Set origin name
result.ColumnInfo.Name = model.NewCIStr(f.Expr.String())
if len(f.AsName) > 0 {
result.Name = f.AsName
} else {
result.Name = f.Expr.String()
result.Name = result.ColumnInfo.Name.O
}
return result
}
@ -111,7 +114,8 @@ func (s *SelectList) UpdateAggFields(expr expression.Expression, tableFields []*
// We must add aggregate function to hidden select list
// and use a position expression to fetch its value later.
exprName := expr.String()
if !field.ContainFieldName(exprName, s.ResultFields, field.CheckFieldFlag) {
idx := field.GetResultFieldIndex(exprName, s.ResultFields, field.CheckFieldFlag)
if len(idx) == 0 {
f := &field.Field{Expr: expr}
resultField := &field.ResultField{Name: exprName}
s.AddField(f, resultField)
@ -119,7 +123,8 @@ func (s *SelectList) UpdateAggFields(expr expression.Expression, tableFields []*
return &expression.Position{N: len(s.Fields), Name: exprName}, nil
}
return nil, nil
// select list has this field, use it directly.
return &expression.Position{N: idx[0] + 1, Name: exprName}, nil
}
// CloneHiddenField checks and clones field and result field from table fields,
@ -140,6 +145,55 @@ func (s *SelectList) CloneHiddenField(name string, tableFields []*field.ResultFi
return false
}
// CheckReferAmbiguous checks whether an identifier reference is ambiguous or not in select list.
// e,g, "select c1 as a, c2 as a from t group by a" is ambiguous,
// but "select c1 as a, c1 as a from t group by a" is not.
// For MySQL "select c1 as a, c2 + 1 as a from t group by a" is not ambiguous too,
// so we will only check identifier too.
// If no ambiguous, -1 means expr refers none in select list, else an index in select list returns.
func (s *SelectList) CheckReferAmbiguous(expr expression.Expression) (int, error) {
if _, ok := expr.(*expression.Ident); !ok {
return -1, nil
}
name := expr.String()
if field.IsQualifiedName(name) {
// name is qualified, no need to check
return -1, nil
}
lastIndex := -1
// only check origin select list, no hidden field.
for i := 0; i < s.HiddenFieldOffset; i++ {
if !strings.EqualFold(s.ResultFields[i].Name, name) {
continue
} else if _, ok := s.Fields[i].Expr.(*expression.Ident); !ok {
// not identfier, no check
continue
}
if lastIndex == -1 {
// first match, continue
lastIndex = i
continue
}
// check origin name, e,g. "select c1 as c2, c2 from t group by c2" is ambiguous.
if s.ResultFields[i].ColumnInfo.Name.L != s.ResultFields[lastIndex].ColumnInfo.Name.L {
return -1, errors.Errorf("refer %s is ambiguous", expr)
}
// check table name, e.g, "select t.c1, c1 from t group by c1" is not ambiguous.
if s.ResultFields[i].TableName != s.ResultFields[lastIndex].TableName {
return -1, errors.Errorf("refer %s is ambiguous", expr)
}
// TODO: check database name if possible.
}
return lastIndex, nil
}
// ResolveSelectList gets fields and result fields from selectFields and srcFields,
// including field validity check and wildcard field processing.
func ResolveSelectList(selectFields []*field.Field, srcFields []*field.ResultField) (*SelectList, error) {
@ -180,8 +234,16 @@ func ResolveSelectList(selectFields []*field.Field, srcFields []*field.ResultFie
}
var result *field.ResultField
if err = field.CheckAllFieldNames(names, srcFields, field.DefaultFieldFlag); err != nil {
return nil, errors.Trace(err)
for _, name := range names {
idx := field.GetResultFieldIndex(name, srcFields, field.DefaultFieldFlag)
if len(idx) > 1 {
return nil, errors.Errorf("ambiguous field %s", name)
}
// TODO: must check in outer query too.
if len(idx) == 0 {
return nil, errors.Errorf("unknown field %s", name)
}
}
if _, ok := v.Expr.(*expression.Ident); ok {

View File

@ -0,0 +1,111 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package plans_test
import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/field"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/plan/plans"
)
type testSelectListSuite struct{}
var _ = Suite(&testSelectListSuite{})
func (s *testSelectListSuite) createTestResultFields(colNames []string) []*field.ResultField {
fs := make([]*field.ResultField, 0, len(colNames))
for i := 0; i < len(colNames); i++ {
f := &field.ResultField{
TableName: "t",
Name: colNames[i],
}
f.ColumnInfo.Name = model.NewCIStr(colNames[i])
fs = append(fs, f)
}
return fs
}
func (s *testSelectListSuite) TestAmbiguous(c *C) {
type pair struct {
Name string
AsName string
}
createResultFields := func(names []string) []*field.ResultField {
fs := make([]*field.ResultField, 0, len(names))
for i := 0; i < len(names); i++ {
f := &field.ResultField{
TableName: "t",
Name: names[i],
}
f.ColumnInfo.Name = model.NewCIStr(names[i])
fs = append(fs, f)
}
return fs
}
createFields := func(ps []pair) []*field.Field {
fields := make([]*field.Field, len(ps))
for i, f := range ps {
fields[i] = &field.Field{
Expr: &expression.Ident{
CIStr: model.NewCIStr(f.Name),
},
AsName: f.AsName,
}
}
return fields
}
tbl := []struct {
Fields []pair
Name string
Err bool
Index int
}{
{[]pair{{"id", ""}}, "id", false, 0},
{[]pair{{"id", "a"}, {"name", "a"}}, "a", true, -1},
{[]pair{{"id", "a"}, {"name", "a"}}, "id", false, -1},
}
for _, t := range tbl {
rs := createResultFields([]string{"id", "name"})
fs := createFields(t.Fields)
sl, err := plans.ResolveSelectList(fs, rs)
c.Assert(err, IsNil)
index, err := sl.CheckReferAmbiguous(&expression.Ident{
CIStr: model.NewCIStr(t.Name),
})
if t.Err {
c.Assert(err, NotNil)
continue
}
c.Assert(err, IsNil)
c.Assert(t.Index, Equals, index)
}
}

View File

@ -132,7 +132,9 @@ func (s *ShowPlan) Filter(ctx context.Context, expr expression.Expression) (plan
// Next implements plan.Plan Next interface.
func (s *ShowPlan) Next(ctx context.Context) (row *plan.Row, err error) {
if s.rows == nil {
s.fetchAll(ctx)
if err := s.fetchAll(ctx); err != nil {
return nil, errors.Trace(err)
}
}
if s.cursor == len(s.rows) {
return
@ -192,12 +194,13 @@ func (s *ShowPlan) getTable(ctx context.Context) (table.Table, error) {
is := sessionctx.GetDomain(ctx).InfoSchema()
dbName := model.NewCIStr(s.DBName)
if !is.SchemaExists(dbName) {
return nil, errors.Errorf("Can not find DB: %s", dbName)
// MySQL returns no such table here if database doesn't exist.
return nil, errors.Trace(mysql.NewErr(mysql.ErrNoSuchTable, s.DBName, s.TableName))
}
tbName := model.NewCIStr(s.TableName)
tb, err := is.TableByName(dbName, tbName)
if err != nil {
return nil, errors.Errorf("Can not find table: %s", s.TableName)
return nil, errors.Trace(mysql.NewErr(mysql.ErrNoSuchTable, s.DBName, s.TableName))
}
return tb, nil
}

View File

@ -224,4 +224,9 @@ func (p *testShowSuit) TestShowTables(c *C) {
c.Assert(cnt, Equals, 2)
cnt = mustQuery(c, testDB, `show full tables where Table_type != 'VIEW';`)
c.Assert(cnt, Equals, 3)
mustQuery(c, testDB, `show create table tab00;`)
rows, _ := testDB.Query(`show create table abc;`)
rows.Next()
c.Assert(rows.Err(), NotNil)
}

View File

@ -39,32 +39,9 @@ type GroupByRset struct {
SelectList *plans.SelectList
}
// HasAmbiguousField checks whether have ambiguous group by fields.
func (r *GroupByRset) HasAmbiguousField(indices []int, fields []*field.Field) bool {
columnNameMap := map[string]struct{}{}
for _, index := range indices {
expr := fields[index].Expr
// `select c1 + c2 as c1, c1 + c3 as c1 from t order by c1` is valid,
// if it is not `Ident` expression, ignore it.
v, ok := expr.(*expression.Ident)
if !ok {
continue
}
// `select c1 as c2, c1 as c2 from t order by c2` is valid,
// use a map for it here.
columnNameMap[v.L] = struct{}{}
}
return len(columnNameMap) > 1
}
// Plan gets GroupByDefaultPlan.
func (r *GroupByRset) Plan(ctx context.Context) (plan.Plan, error) {
fields := r.SelectList.Fields
resultfields := r.SelectList.ResultFields
srcFields := r.Src.GetFields()
r.SelectList.AggFields = GetAggFields(fields)
aggFields := r.SelectList.AggFields
@ -93,30 +70,25 @@ func (r *GroupByRset) Plan(ctx context.Context) (plan.Plan, error) {
// use Position expression for the associated field.
r.By[i] = &expression.Position{N: position}
} else {
index, err := r.SelectList.CheckReferAmbiguous(e)
if err != nil {
return nil, errors.Errorf("Column '%s' in group statement is ambiguous", e)
} else if _, ok := aggFields[index]; ok {
return nil, errors.Errorf("Can't group on '%s'", e)
}
// TODO: check more ambiguous case
// Group by ambiguous rule:
// select c1 as a, c2 as a from t group by a is ambiguous
// select c1 as a, c2 as a from t group by a + 1 is ambiguous
// select c1 as c2, c2 from t group by c2 is ambiguous
// select c1 as c2, c2 from t group by c2 + 1 is ambiguous
// TODO: use visitor to check aggregate function
names := expression.MentionedColumns(e)
for _, name := range names {
if field.ContainFieldName(name, srcFields, field.DefaultFieldFlag) {
// check whether column is qualified, like `select t.c1 c1, t.c2 from t group by t.c1, t.c2`
// no need to check ambiguous field.
if expression.IsQualified(name) {
continue
}
// check ambiguous fields, like `select c1 as c2, c2 from t group by c2`.
if err := field.CheckAmbiguousField(name, resultfields, field.DefaultFieldFlag); err == nil {
continue
}
}
// check reference to group function name
indices := field.GetFieldIndex(name, fields[0:r.SelectList.HiddenFieldOffset], field.CheckFieldFlag)
if len(indices) > 1 {
// check ambiguous fields, like `select c1 as a, c2 as a from t group by a`,
// notice that `select c2 as c2, c2 as c2 from t group by c2;` is valid.
if r.HasAmbiguousField(indices, fields[0:r.SelectList.HiddenFieldOffset]) {
return nil, errors.Errorf("Column '%s' in group statement is ambiguous", name)
}
} else if len(indices) == 1 {
indices := field.GetFieldIndex(name, fields[0:r.SelectList.HiddenFieldOffset], field.DefaultFieldFlag)
if len(indices) == 1 {
// check reference to aggregate function, like `select c1, count(c1) as b from t group by b + 1`.
index := indices[0]
if _, ok := aggFields[index]; ok {

View File

@ -125,48 +125,6 @@ func (s *testGroupByRsetSuite) TestGroupByRsetPlan(c *C) {
c.Assert(err, NotNil)
}
func (s *testGroupByRsetSuite) TestGroupByHasAmbiguousField(c *C) {
fld := &field.Field{Expr: expression.Value{Val: 1}}
// check `1`
fields := []*field.Field{fld}
indices := []int{0}
ret := s.r.HasAmbiguousField(indices, fields)
c.Assert(ret, IsFalse)
// check `c1 as c2, c1 as c2`
fld = &field.Field{Expr: &expression.Ident{CIStr: model.NewCIStr("c1")}, AsName: "c2"}
fields = []*field.Field{fld, fld}
indices = []int{0, 1}
ret = s.r.HasAmbiguousField(indices, fields)
c.Assert(ret, IsFalse)
// check `c1+c2 as c2, c1+c3 as c2`
exprx := expression.NewBinaryOperation(opcode.Plus, expression.Value{Val: "c1"},
expression.Value{Val: "c2"})
expry := expression.NewBinaryOperation(opcode.Plus, expression.Value{Val: "c1"},
expression.Value{Val: "c3"})
fldx := &field.Field{Expr: exprx, AsName: "c2"}
fldy := &field.Field{Expr: expry, AsName: "c2"}
fields = []*field.Field{fldx, fldy}
indices = []int{0, 1}
ret = s.r.HasAmbiguousField(indices, fields)
c.Assert(ret, IsFalse)
// check `c1 as c2, c3 as c2`
fldx = &field.Field{Expr: &expression.Ident{CIStr: model.NewCIStr("c1")}, AsName: "c2"}
fldy = &field.Field{Expr: &expression.Ident{CIStr: model.NewCIStr("c3")}, AsName: "c2"}
fields = []*field.Field{fldx, fldy}
indices = []int{0, 1}
ret = s.r.HasAmbiguousField(indices, fields)
c.Assert(ret, IsTrue)
}
func (s *testGroupByRsetSuite) TestGroupByRsetString(c *C) {
str := s.r.String()
c.Assert(len(str), Greater, 0)

View File

@ -28,8 +28,9 @@ var (
// HavingRset is record set for having fields.
type HavingRset struct {
Src plan.Plan
Expr expression.Expression
Src plan.Plan
Expr expression.Expression
SelectList *plans.SelectList
}
// CheckAndUpdateSelectList checks having fields validity and set hidden fields to selectList.
@ -90,7 +91,7 @@ func (r *HavingRset) CheckAndUpdateSelectList(selectList *plans.SelectList, grou
// Plan gets HavingPlan.
func (r *HavingRset) Plan(ctx context.Context) (plan.Plan, error) {
return &plans.HavingPlan{Src: r.Src, Expr: r.Expr}, nil
return &plans.HavingPlan{Src: r.Src, Expr: r.Expr, SelectList: r.SelectList}, nil
}
// String implements fmt.Stringer interface.

View File

@ -35,12 +35,7 @@ func (s *testHavingRsetSuite) SetUpSuite(c *C) {
// expr `id > 1`
expr := expression.NewBinaryOperation(opcode.GT, &expression.Ident{CIStr: model.NewCIStr("id")}, expression.Value{Val: 1})
s.r = &HavingRset{Src: tblPlan, Expr: expr}
}
func (s *testHavingRsetSuite) TestHavingRsetCheckAndUpdateSelectList(c *C) {
resultFields := s.r.Src.GetFields()
resultFields := tblPlan.GetFields()
fields := make([]*field.Field, len(resultFields))
for i, resultField := range resultFields {
name := resultField.Name
@ -52,6 +47,13 @@ func (s *testHavingRsetSuite) TestHavingRsetCheckAndUpdateSelectList(c *C) {
ResultFields: resultFields,
Fields: fields,
}
s.r = &HavingRset{Src: tblPlan, Expr: expr, SelectList: selectList}
}
func (s *testHavingRsetSuite) TestHavingRsetCheckAndUpdateSelectList(c *C) {
resultFields := s.r.Src.GetFields()
selectList := s.r.SelectList
groupBy := []expression.Expression{}

View File

@ -76,16 +76,24 @@ func (r *OrderByRset) CheckAndUpdateSelectList(selectList *plans.SelectList, tab
r.By[i].Expr = expr
} else {
if _, err := selectList.CheckReferAmbiguous(v.Expr); err != nil {
return errors.Errorf("Column '%s' in order statement is ambiguous", v.Expr)
}
// TODO: check more ambiguous case
// Order by ambiguous rule:
// select c1 as a, c2 as a from t order by a is ambiguous
// select c1 as a, c2 as a from t order by a + 1 is ambiguous
// select c1 as c2, c2 from t order by c2 is ambiguous
// select c1 as c2, c2 from t order by c2 + 1 is ambiguous
// TODO: use vistor to refactor all and combine following plan check.
names := expression.MentionedColumns(v.Expr)
for _, name := range names {
// try to find in select list
// TODO: mysql has confused result for this, see #555.
// now we use select list then order by, later we should make it easier.
if field.ContainFieldName(name, selectList.ResultFields, field.CheckFieldFlag) {
// check ambiguous fields, like `select c1 as c2, c2 from t order by c2`.
if err := field.CheckAmbiguousField(name, selectList.ResultFields, field.DefaultFieldFlag); err != nil {
return errors.Errorf("Column '%s' in order statement is ambiguous", name)
}
continue
}
@ -140,9 +148,16 @@ func (r *OrderByRset) Plan(ctx context.Context) (plan.Plan, error) {
r.By[i].Expr = &expression.Position{N: position}
}
} else {
// Don't check ambiguous here, only check field exists or not.
// TODO: use visitor to refactor.
colNames := expression.MentionedColumns(e)
if err := field.CheckAllFieldNames(colNames, fields, field.CheckFieldFlag); err != nil {
return nil, errors.Trace(err)
for _, name := range colNames {
if idx := field.GetResultFieldIndex(name, r.SelectList.ResultFields, field.DefaultFieldFlag); len(idx) == 0 {
// find in from
if idx = field.GetResultFieldIndex(name, r.SelectList.FromFields, field.DefaultFieldFlag); len(idx) == 0 {
return nil, errors.Errorf("unknown field %s", name)
}
}
}
}

View File

@ -18,6 +18,7 @@
package rsets
import (
"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/expression"
@ -188,13 +189,18 @@ func (r *WhereRset) Plan(ctx context.Context) (plan.Plan, error) {
return r.planStatic(ctx, expr)
}
var (
src = r.Src
err error
)
switch x := expr.(type) {
case *expression.BinaryOperation:
return r.planBinOp(ctx, x)
src, err = r.planBinOp(ctx, x)
case *expression.Ident:
return r.planIdent(ctx, x)
src, err = r.planIdent(ctx, x)
case *expression.IsNull:
return r.planIsNull(ctx, x)
src, err = r.planIsNull(ctx, x)
case *expression.PatternIn:
// TODO: optimize
// TODO: show plan
@ -203,10 +209,22 @@ func (r *WhereRset) Plan(ctx context.Context) (plan.Plan, error) {
case *expression.PatternRegexp:
// TODO: optimize
case *expression.UnaryOperation:
return r.planUnaryOp(ctx, x)
src, err = r.planUnaryOp(ctx, x)
default:
log.Warnf("%v not supported in where rset now", r.Expr)
}
return &plans.FilterDefaultPlan{Plan: r.Src, Expr: expr}, nil
if err != nil {
return nil, errors.Trace(err)
}
if _, ok := src.(*plans.FilterDefaultPlan); ok {
return src, nil
}
// We must use a FilterDefaultPlan here to wrap filtered plan.
// Alghough we can check where condition using index plan, we still need
// to check again after the FROM phase if the FROM phase contains outer join.
// TODO: if FROM phase doesn't contain outer join, we can return filtered plan directly.
return &plans.FilterDefaultPlan{Plan: src, Expr: expr}, nil
}

View File

@ -28,6 +28,7 @@ import (
"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/field"
"github.com/pingcap/tidb/kv"
mysql "github.com/pingcap/tidb/mysqldef"
@ -40,6 +41,7 @@ import (
"github.com/pingcap/tidb/stmt/stmts"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/errors2"
"github.com/pingcap/tidb/util/sqlexec"
)
// Session context
@ -228,6 +230,34 @@ func (s *session) Retry() error {
return nil
}
// ExecRestrictedSQL implements SQLHelper interface.
// This is used for executing some restricted sql statements.
func (s *session) ExecRestrictedSQL(ctx context.Context, sql string) (rset.Recordset, error) {
if ctx.Value(&sqlexec.RestrictedSQLExecutorKeyType{}) != nil {
// We do not support run this function concurrently.
// TODO: Maybe we should remove this restriction latter.
return nil, errors.New("Should not call ExecRestrictedSQL concurrently.")
}
statements, err := Compile(sql)
if err != nil {
log.Errorf("Compile %s with error: %v", sql, err)
return nil, errors.Trace(err)
}
if len(statements) != 1 {
log.Errorf("ExecRestrictedSQL only executes one statement. Too many/few statement in %s", sql)
return nil, errors.New("Wrong number of statement.")
}
st := statements[0]
// Check statement for some restriction
// For example only support DML on system meta table.
// TODO: Add more restrictions.
log.Infof("Executing %s [%s]", st, sql)
ctx.SetValue(&sqlexec.RestrictedSQLExecutorKeyType{}, true)
defer ctx.ClearValue(&sqlexec.RestrictedSQLExecutorKeyType{})
rs, err := st.Exec(ctx)
return rs, errors.Trace(err)
}
func (s *session) Execute(sql string) ([]rset.Recordset, error) {
statements, err := Compile(sql)
if err != nil {

View File

@ -18,21 +18,18 @@
package stmts
import (
"fmt"
"strings"
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/parser/coldef"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/rset"
"github.com/pingcap/tidb/rset/rsets"
"github.com/pingcap/tidb/stmt"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/format"
"github.com/pingcap/tidb/util/sqlexec"
)
/************************************************************************************
@ -73,39 +70,13 @@ func (s *CreateUserStmt) SetText(text string) {
s.Text = text
}
func composeUserTableFilter(name string, host string) expression.Expression {
nameMatch := expression.NewBinaryOperation(opcode.EQ, &expression.Ident{CIStr: model.NewCIStr("User")}, &expression.Value{Val: name})
hostMatch := expression.NewBinaryOperation(opcode.EQ, &expression.Ident{CIStr: model.NewCIStr("Host")}, &expression.Value{Val: host})
return expression.NewBinaryOperation(opcode.AndAnd, nameMatch, hostMatch)
}
func composeUserTableRset() *rsets.JoinRset {
return &rsets.JoinRset{
Left: &rsets.TableSource{
Source: table.Ident{
Name: model.NewCIStr(mysql.UserTable),
Schema: model.NewCIStr(mysql.SystemDB),
},
},
}
}
func (s *CreateUserStmt) userExists(ctx context.Context, name string, host string) (bool, error) {
r := composeUserTableRset()
p, err := r.Plan(ctx)
sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND Host="%s";`, mysql.SystemDB, mysql.UserTable, name, host)
rs, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
if err != nil {
return false, errors.Trace(err)
}
where := &rsets.WhereRset{
Src: p,
Expr: composeUserTableFilter(name, host),
}
p, err = where.Plan(ctx)
if err != nil {
return false, errors.Trace(err)
}
defer p.Close()
row, err := p.Next(ctx)
row, err := rs.Next()
if err != nil {
return false, errors.Trace(err)
}
@ -114,14 +85,7 @@ func (s *CreateUserStmt) userExists(ctx context.Context, name string, host strin
// Exec implements the stmt.Statement Exec interface.
func (s *CreateUserStmt) Exec(ctx context.Context) (rset.Recordset, error) {
st := &InsertIntoStmt{
TableIdent: table.Ident{
Name: model.NewCIStr(mysql.UserTable),
Schema: model.NewCIStr(mysql.SystemDB),
},
ColNames: []string{"Host", "User", "Password"},
}
values := make([][]expression.Expression, 0, len(s.Specs))
users := make([]string, 0, len(s.Specs))
for _, spec := range s.Specs {
strs := strings.Split(spec.User, "@")
userName := strs[0]
@ -136,22 +100,20 @@ func (s *CreateUserStmt) Exec(ctx context.Context) (rset.Recordset, error) {
}
continue
}
value := make([]expression.Expression, 0, 3)
value = append(value, expression.Value{Val: host})
value = append(value, expression.Value{Val: userName})
pwd := ""
if spec.AuthOpt.ByAuthString {
value = append(value, expression.Value{Val: util.EncodePassword(spec.AuthOpt.AuthString)})
pwd = util.EncodePassword(spec.AuthOpt.AuthString)
} else {
// TODO: Maybe we should hash the string here?
value = append(value, expression.Value{Val: util.EncodePassword(spec.AuthOpt.HashString)})
pwd = util.EncodePassword(spec.AuthOpt.HashString)
}
values = append(values, value)
user := fmt.Sprintf(`("%s", "%s", "%s")`, host, userName, pwd)
users = append(users, user)
}
if len(values) == 0 {
if len(users) == 0 {
return nil, nil
}
st.Lists = values
_, err := st.Exec(ctx)
sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, Password) VALUES %s;`, mysql.SystemDB, mysql.UserTable, strings.Join(users, ", "))
_, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
if err != nil {
return nil, errors.Trace(err)
}
@ -188,20 +150,13 @@ func (s *SetPwdStmt) SetText(text string) {
}
// Exec implements the stmt.Statement Exec interface.
func (s *SetPwdStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) {
// If len(s.User) == 0, use CURRENT_USER()
func (s *SetPwdStmt) Exec(ctx context.Context) (rset.Recordset, error) {
// TODO: If len(s.User) == 0, use CURRENT_USER()
strs := strings.Split(s.User, "@")
userName := strs[0]
host := strs[1]
// Update mysql.user
asgn := expression.Assignment{
ColName: "Password",
Expr: expression.Value{Val: util.EncodePassword(s.Password)},
}
st := &UpdateStmt{
TableRefs: composeUserTableRset(),
List: []expression.Assignment{asgn},
Where: composeUserTableFilter(userName, host),
}
return st.Exec(ctx)
sql := fmt.Sprintf(`UPDATE %s.%s SET password="%s" WHERE User="%s" AND Host="%s";`, mysql.SystemDB, mysql.UserTable, util.EncodePassword(s.Password), userName, host)
_, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
return nil, errors.Trace(err)
}

View File

@ -127,23 +127,27 @@ func (s *DeleteStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) {
defer p.Close()
tblIDMap := make(map[int64]bool, len(s.TableIdents))
// Get table alias map.
tblAliasMap := make(map[string]string)
tblNames := make(map[string]string)
if s.MultiTable {
// Delete from multiple tables should consider table ident list.
fs := p.GetFields()
for _, f := range fs {
if f.TableName != f.OrgTableName {
tblAliasMap[f.TableName] = f.OrgTableName
tblNames[f.TableName] = f.OrgTableName
} else {
tblNames[f.TableName] = f.TableName
}
}
for _, t := range s.TableIdents {
// Consider DBName.
oname, ok := tblAliasMap[t.Name.O]
if ok {
t.Name.O = oname
t.Name.L = strings.ToLower(oname)
oname, ok := tblNames[t.Name.O]
if !ok {
return nil, errors.Errorf("Unknown table '%s' in MULTI DELETE", t.Name.O)
}
t.Name.O = oname
t.Name.L = strings.ToLower(oname)
var tbl table.Table
tbl, err = getTable(ctx, t)
if err != nil {

View File

@ -199,5 +199,8 @@ func (s *testStmtSuite) TestQualifedDelete(c *C) {
r = mustExec(c, s.testDB, "delete a, b from t1 as a join t2 as b where a.c2 = b.c1")
checkResult(c, r, 2, 0)
_, err = s.testDB.Exec("delete t1, t2 from t1 as a join t2 as b where a.c2 = b.c1")
c.Assert(err, NotNil)
mustExec(c, s.testDB, "drop table t1, t2")
}

View File

@ -22,6 +22,7 @@ import (
"github.com/pingcap/tidb/column"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/field"
"github.com/pingcap/tidb/kv"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/plan"
@ -232,6 +233,12 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error)
}
insertValueCount := len(s.Lists[0])
toUpdateColumns, err0 := getOnDuplicateUpdateColumns(s.OnDuplicate, t)
if err0 != nil {
return nil, errors.Trace(err0)
}
toUpdateArgs := map[interface{}]interface{}{}
for i, list := range s.Lists {
r := make([]interface{}, len(tableCols))
valueCount := len(list)
@ -294,15 +301,20 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error)
// On duplicate key Update the duplicate row.
// Evaluate the updated value.
// TODO: report rows affected and last insert id.
toUpdateColumns, _, err := getUpdateColumns(t, s.OnDuplicate, false, nil)
if err != nil {
return nil, errors.Trace(err)
}
data, err := t.Row(ctx, h)
if err != nil {
return nil, errors.Trace(err)
}
err = updateRecord(ctx, h, data, t, toUpdateColumns, s.OnDuplicate, r, nil)
toUpdateArgs[expression.ExprEvalValuesFunc] = func(name string) (interface{}, error) {
c, err1 := findColumnByName(t, name)
if err1 != nil {
return nil, errors.Trace(err1)
}
return r[c.Offset], nil
}
err = updateRecord(ctx, h, data, t, toUpdateColumns, toUpdateArgs, 0, true)
if err != nil {
return nil, errors.Trace(err)
}
@ -311,6 +323,19 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error)
return nil, nil
}
func getOnDuplicateUpdateColumns(assignList []expression.Assignment, t table.Table) (map[int]expression.Assignment, error) {
m := make(map[int]expression.Assignment, len(assignList))
for _, v := range assignList {
c, err := findColumnByName(t, field.JoinQualifiedName("", v.TableName, v.ColName))
if err != nil {
return nil, errors.Trace(err)
}
m[c.Offset] = v
}
return m, nil
}
func (s *InsertIntoStmt) initDefaultValues(ctx context.Context, t table.Table, cols []*column.Col, row []interface{}, marked map[int]struct{}) error {
var err error
var defaultValueCols []*column.Col

View File

@ -89,4 +89,10 @@ func (s *testStmtSuite) TestInsert(c *C) {
insertSQL := `insert into insert_test (id, c2) values (1, 1) on duplicate key update c2=10;`
mustExec(c, s.testDB, insertSQL)
insertSQL = `insert into insert_test (id, c2) values (1, 1) on duplicate key update insert_test.c2=10;`
mustExec(c, s.testDB, insertSQL)
_, err = s.testDB.Exec(`insert into insert_test (id, c2) values(1, 1) on duplicate key update t.c2 = 10`)
c.Assert(err, NotNil)
}

View File

@ -207,8 +207,9 @@ func (s *SelectStmt) Plan(ctx context.Context) (plan.Plan, error) {
if s := s.Having; s != nil {
if r, err = (&rsets.HavingRset{
Src: r,
Expr: s.Expr}).Plan(ctx); err != nil {
Src: r,
Expr: s.Expr,
SelectList: selectList}).Plan(ctx); err != nil {
return nil, err
}
}

View File

@ -18,7 +18,7 @@
package stmts
import (
"strings"
"fmt"
"github.com/juju/errors"
"github.com/ngaut/log"
@ -84,47 +84,45 @@ func (s *UpdateStmt) SetText(text string) {
s.Text = text
}
func getUpdateColumns(t table.Table, assignList []expression.Assignment, isMultipleTable bool, tblAliasMap map[string]string) ([]*column.Col, []expression.Assignment, error) {
// TODO: We should check the validate if assignList in somewhere else. Maybe in building plan.
// TODO: We should use field.GetFieldIndex to replace this function.
tcols := make([]*column.Col, 0, len(assignList))
tAsgns := make([]expression.Assignment, 0, len(assignList))
tname := t.TableName()
for _, asgn := range assignList {
if isMultipleTable {
if tblAliasMap != nil {
if alias, ok := tblAliasMap[asgn.TableName]; ok {
if !strings.EqualFold(tname.O, alias) {
continue
}
}
} else if !strings.EqualFold(tname.O, asgn.TableName) {
continue
}
}
col := column.FindCol(t.Cols(), asgn.ColName)
if col == nil {
if isMultipleTable {
continue
}
return nil, nil, errors.Errorf("UPDATE: unknown column %s", asgn.ColName)
}
tcols = append(tcols, col)
tAsgns = append(tAsgns, asgn)
func findColumnByName(t table.Table, name string) (*column.Col, error) {
_, tableName, colName := field.SplitQualifiedName(name)
if len(tableName) > 0 && tableName != t.TableName().O {
return nil, errors.Errorf("unknown field %s.%s", tableName, colName)
}
return tcols, tAsgns, nil
c := column.FindCol(t.Cols(), colName)
if c == nil {
return nil, errors.Errorf("unknown field %s", colName)
}
return c, nil
}
func getInsertValue(name string, cols []*column.Col, row []interface{}) (interface{}, error) {
for i, col := range cols {
if col.Name.L == name {
return row[i], nil
func getUpdateColumns(assignList []expression.Assignment, fields []*field.ResultField) (map[int]expression.Assignment, error) {
m := make(map[int]expression.Assignment, len(assignList))
for _, v := range assignList {
name := v.ColName
if len(v.TableName) > 0 {
name = fmt.Sprintf("%s.%s", v.TableName, v.ColName)
}
// use result fields to check assign list, otherwise use origin table columns
idx := field.GetResultFieldIndex(name, fields, field.DefaultFieldFlag)
if n := len(idx); n > 1 {
return nil, errors.Errorf("ambiguous field %s", name)
} else if n == 0 {
return nil, errors.Errorf("unknown field %s", name)
}
m[idx[0]] = v
}
return nil, errors.Errorf("unknown field %s", name)
return m, nil
}
func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Table, tcols []*column.Col, assignList []expression.Assignment, insertData []interface{}, args map[interface{}]interface{}) error {
func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Table,
updateColumns map[int]expression.Assignment, m map[interface{}]interface{},
offset int, onDuplicateUpdate bool) error {
if err := t.LockRow(ctx, h, true); err != nil {
return errors.Trace(err)
}
@ -133,29 +131,27 @@ func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Tabl
touched := make([]bool, len(t.Cols()))
copy(oldData, data)
// Generate new values
m := args
if m == nil {
m = make(map[interface{}]interface{}, len(t.Cols()))
// Set parameter for evaluating expression.
for _, col := range t.Cols() {
m[col.Name.L] = data[col.Offset]
}
}
if insertData != nil {
m[expression.ExprEvalValuesFunc] = func(name string) (interface{}, error) {
return getInsertValue(name, t.Cols(), insertData)
}
}
cols := t.Cols()
for i, asgn := range assignList {
assignExists := false
for i, asgn := range updateColumns {
if i < offset || i >= offset+len(cols) {
// The assign expression is for another table, not this.
continue
}
val, err := asgn.Expr.Eval(ctx, m)
if err != nil {
return err
}
colIndex := tcols[i].Offset
colIndex := i - offset
touched[colIndex] = true
data[colIndex] = val
assignExists = true
}
// no assign list for this table, no need to update.
if !assignExists {
return nil
}
// Check whether new value is valid.
@ -198,7 +194,7 @@ func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Tabl
return errors.Trace(err)
}
// Record affected rows.
if len(insertData) == 0 {
if !onDuplicateUpdate {
variable.GetSessionVars(ctx).AddAffectedRows(1)
} else {
variable.GetSessionVars(ctx).AddAffectedRows(2)
@ -249,17 +245,17 @@ func (s *UpdateStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) {
}
defer p.Close()
updatedRowKeys := make(map[string]bool)
// For single-table syntax, TableRef may contain multiple tables
isMultipleTable := s.MultipleTable || s.TableRefs.MultipleTable()
// Get table alias map.
fs := p.GetFields()
tblAliasMap := make(map[string]string)
for _, f := range fs {
if f.TableName != f.OrgTableName {
tblAliasMap[f.TableName] = f.OrgTableName
}
columns, err0 := getUpdateColumns(s.List, fs)
if err0 != nil {
return nil, errors.Trace(err0)
}
m := map[interface{}]interface{}{}
var records []*plan.Row
for {
row, err1 := p.Next(ctx)
if err1 != nil {
@ -268,21 +264,30 @@ func (s *UpdateStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) {
if row == nil {
break
}
rowData := row.Data
if len(row.RowKeys) == 0 {
// Nothing to update
return nil, nil
continue
}
records = append(records, row)
}
for _, row := range records {
rowData := row.Data
// Set EvalIdentFunc
m := make(map[interface{}]interface{})
m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) {
return plans.GetIdentValue(name, p.GetFields(), rowData, field.DefaultFieldFlag)
}
// Update rows
start := 0
offset := 0
for _, entry := range row.RowKeys {
tbl := entry.Tbl
k := entry.Key
lastOffset := offset
offset += len(tbl.Cols())
data := rowData[lastOffset:offset]
_, ok := updatedRowKeys[k]
if ok {
// Each matching row is updated once, even if it matches the conditions multiple times.
@ -293,23 +298,12 @@ func (s *UpdateStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) {
if err2 != nil {
return nil, errors.Trace(err2)
}
end := start + len(tbl.Cols())
data := rowData[start:end]
start = end
// For multiple table mode, get to-update cols and to-update assginments.
tcols, tAsgns, err2 := getUpdateColumns(tbl, s.List, isMultipleTable, tblAliasMap)
if err2 != nil {
return nil, errors.Trace(err2)
}
if len(tcols) == 0 {
// Nothing to update for this table.
continue
}
// Get data in the table
err2 = updateRecord(ctx, handle, data, tbl, tcols, tAsgns, nil, m)
err2 = updateRecord(ctx, handle, data, tbl, columns, m, lastOffset, false)
if err2 != nil {
return nil, errors.Trace(err2)
}
updatedRowKeys[k] = true
}
}

View File

@ -179,11 +179,15 @@ func (s *testStmtSuite) TestMultipleTableUpdate(c *C) {
func (s *testStmtSuite) TestIssue345(c *C) {
testDB, err := sql.Open(tidb.DriverName, tidb.EngineGoLevelDBMemory+"tmp-issue345/"+s.dbName)
c.Assert(err, IsNil)
mustExec(c, testDB, `drop table if exists t1, t2`)
mustExec(c, testDB, `create table t1 (c1 int);`)
mustExec(c, testDB, `create table t2 (c2 int);`)
mustExec(c, testDB, `insert into t1 values (1);`)
mustExec(c, testDB, `insert into t2 values (2);`)
mustExec(c, testDB, `update t1, t2 set t1.c1 = 2, t2.c2 = 1;`)
mustExec(c, testDB, `update t1, t2 set c1 = 2, c2 = 1;`)
mustExec(c, testDB, `update t1 as a, t2 as b set a.c1 = 2, b.c2 = 1;`)
// Check t1 content
tx := mustBegin(c, testDB)
rows, err := tx.Query("SELECT * FROM t1;")
@ -212,5 +216,34 @@ func (s *testStmtSuite) TestIssue345(c *C) {
c.Assert(err, IsNil)
matchRows(c, rows, [][]interface{}{{2}})
rows.Close()
_, err = testDB.Exec(`update t1 as a, t2 set t1.c1 = 10;`)
c.Assert(err, NotNil)
mustCommit(c, tx)
}
func (s *testStmtSuite) TestMultiUpdate(c *C) {
// fix https://github.com/pingcap/tidb/issues/369
testSQL := `
DROP TABLE IF EXISTS t1, t2;
create table t1 (c int);
create table t2 (c varchar(256));
insert into t1 values (1), (2);
insert into t2 values ("a"), ("b");
update t1, t2 set t1.c = 10, t2.c = "abc";`
mustExec(c, s.testDB, testSQL)
// fix https://github.com/pingcap/tidb/issues/376
testSQL = `DROP TABLE IF EXISTS t1, t2;
create table t1 (c1 int);
create table t2 (c2 int);
insert into t1 values (1), (2);
insert into t2 values (1), (2);
update t1, t2 set t1.c1 = 10, t2.c2 = 2 where t2.c2 = 1;`
mustExec(c, s.testDB, testSQL)
rows, err := s.testDB.Query("select * from t1")
c.Assert(err, IsNil)
matchRows(c, rows, [][]interface{}{{10}, {10}})
}

View File

@ -103,7 +103,6 @@ func (s *dbStore) GetSnapshot() (kv.MvccSnapshot, error) {
if err != nil {
return nil, errors.Trace(err)
}
// dbSnapshot implements MvccSnapshot interface.
return &dbSnapshot{
db: s.db,
version: currentVer,
@ -121,13 +120,13 @@ func (s *dbStore) Begin() (kv.Transaction, error) {
}
txn := &dbTxn{
startTs: time.Now(),
tID: beginVer.Ver,
tid: beginVer.Ver,
valid: true,
store: s,
version: kv.MinVersion,
snapshotVals: make(map[string][]byte),
}
log.Debugf("Begin txn:%d", txn.tID)
log.Debugf("Begin txn:%d", txn.tid)
txn.UnionStore, err = kv.NewUnionStore(&dbSnapshot{
db: s.db,
version: beginVer,
@ -164,7 +163,7 @@ func (s *dbStore) newBatch() engine.Batch {
}
// Both lock and unlock are used for simulating scenario of percolator papers.
func (s *dbStore) tryConditionLockKey(tID uint64, key string, snapshotVal []byte) error {
func (s *dbStore) tryConditionLockKey(tid uint64, key string, snapshotVal []byte) error {
s.mu.Lock()
defer s.mu.Unlock()
@ -187,12 +186,12 @@ func (s *dbStore) tryConditionLockKey(tID uint64, key string, snapshotVal []byte
}
// If there's newer version of this key, returns error.
if ver > tID {
log.Warnf("txn:%d, tryLockKey condition not match for key %s, currValue:%q, snapshotVal:%q", tID, key, currValue, snapshotVal)
if ver > tid {
log.Warnf("txn:%d, tryLockKey condition not match for key %s, currValue:%q, snapshotVal:%q", tid, key, currValue, snapshotVal)
return errors.Trace(kv.ErrConditionNotMatch)
}
s.keysLocked[key] = tID
s.keysLocked[key] = tid
return nil
}

View File

@ -28,6 +28,7 @@ const (
func (l *LocalVersionProvider) CurrentVersion() (kv.Version, error) {
l.mu.Lock()
defer l.mu.Unlock()
var ts uint64
ts = uint64((time.Now().UnixNano() / int64(time.Millisecond)) << timePrecisionOffset)
if l.lastTimeStampTs == uint64(ts) {

View File

@ -29,6 +29,7 @@ var (
_ kv.Iterator = (*dbIter)(nil)
)
// dbSnapshot implements MvccSnapshot interface.
type dbSnapshot struct {
db engine.DB
rawIt engine.Iterator

View File

@ -39,7 +39,7 @@ type dbTxn struct {
kv.UnionStore
store *dbStore // for commit
startTs time.Time
tID uint64
tid uint64
valid bool
version kv.Version // commit version
snapshotVals map[string][]byte // origin version in snapshot
@ -66,7 +66,7 @@ func (txn *dbTxn) markOrigin(k []byte) error {
// Implement transaction interface
func (txn *dbTxn) Inc(k kv.Key, step int64) (int64, error) {
log.Debugf("Inc %q, step %d txn:%d", k, step, txn.tID)
log.Debugf("Inc %q, step %d txn:%d", k, step, txn.tid)
k = kv.EncodeKey(k)
if err := txn.markOrigin(k); err != nil {
@ -113,11 +113,11 @@ func (txn *dbTxn) GetInt64(k kv.Key) (int64, error) {
}
func (txn *dbTxn) String() string {
return fmt.Sprintf("%d", txn.tID)
return fmt.Sprintf("%d", txn.tid)
}
func (txn *dbTxn) Get(k kv.Key) ([]byte, error) {
log.Debugf("get key:%q, txn:%d", k, txn.tID)
log.Debugf("get key:%q, txn:%d", k, txn.tid)
k = kv.EncodeKey(k)
val, err := txn.UnionStore.Get(k)
if kv.IsErrNotFound(err) {
@ -140,7 +140,7 @@ func (txn *dbTxn) Set(k kv.Key, data []byte) error {
return errors.Trace(ErrCannotSetNilValue)
}
log.Debugf("set key:%q, txn:%d", k, txn.tID)
log.Debugf("set key:%q, txn:%d", k, txn.tid)
k = kv.EncodeKey(k)
err := txn.UnionStore.Set(k, data)
if err != nil {
@ -151,7 +151,7 @@ func (txn *dbTxn) Set(k kv.Key, data []byte) error {
}
func (txn *dbTxn) Seek(k kv.Key, fnKeyCmp func(kv.Key) bool) (kv.Iterator, error) {
log.Debugf("seek key:%q, txn:%d", k, txn.tID)
log.Debugf("seek key:%q, txn:%d", k, txn.tid)
k = kv.EncodeKey(k)
iter, err := txn.UnionStore.Seek(k, txn)
@ -172,7 +172,7 @@ func (txn *dbTxn) Seek(k kv.Key, fnKeyCmp func(kv.Key) bool) (kv.Iterator, error
}
func (txn *dbTxn) Delete(k kv.Key) error {
log.Debugf("delete key:%q, txn:%d", k, txn.tID)
log.Debugf("delete key:%q, txn:%d", k, txn.tid)
k = kv.EncodeKey(k)
return txn.UnionStore.Delete(k)
}
@ -198,7 +198,7 @@ func (txn *dbTxn) doCommit() error {
}()
// Check locked keys
for k, v := range txn.snapshotVals {
err := txn.store.tryConditionLockKey(txn.tID, k, v)
err := txn.store.tryConditionLockKey(txn.tid, k, v)
if err != nil {
return errors.Trace(err)
}
@ -236,7 +236,7 @@ func (txn *dbTxn) Commit() error {
if !txn.valid {
return errors.Trace(ErrInvalidTxn)
}
log.Infof("commit txn %d", txn.tID)
log.Infof("commit txn %d", txn.tid)
defer func() {
txn.close()
}()
@ -263,7 +263,7 @@ func (txn *dbTxn) Rollback() error {
if !txn.valid {
return errors.Trace(ErrInvalidTxn)
}
log.Warnf("Rollback txn %d", txn.tID)
log.Warnf("Rollback txn %d", txn.tid)
return txn.close()
}

View File

@ -201,11 +201,11 @@ func (cc *clientConn) readHandshakeResponse() error {
addr := cc.conn.RemoteAddr().String()
host, _, err1 := net.SplitHostPort(addr)
if err1 != nil {
return errors.Trace(mysql.NewDefaultError(mysql.ErAccessDeniedError, cc.user, addr, "Yes"))
return errors.Trace(mysql.NewErr(mysql.ErrAccessDenied, cc.user, addr, "Yes"))
}
user := fmt.Sprintf("%s@%s", cc.user, host)
if !cc.ctx.Auth(user, auth, cc.salt) {
return errors.Trace(mysql.NewDefaultError(mysql.ErAccessDeniedError, cc.user, host, "Yes"))
return errors.Trace(mysql.NewErr(mysql.ErrAccessDenied, cc.user, host, "Yes"))
}
}
return nil
@ -283,8 +283,7 @@ func (cc *clientConn) dispatch(data []byte) error {
case mysql.ComStmtReset:
return cc.handleStmtReset(data)
default:
msg := fmt.Sprintf("command %d not supported now", cmd)
return mysql.NewError(mysql.ErUnknownError, msg)
return mysql.NewErrf(mysql.ErrUnknown, "command %d not supported now", cmd)
}
}
@ -322,8 +321,9 @@ func (cc *clientConn) writeOK() error {
func (cc *clientConn) writeError(e error) error {
var m *mysql.SQLError
var ok bool
if m, ok = e.(*mysql.SQLError); !ok {
m = mysql.NewError(mysql.ErUnknownError, e.Error())
originErr := errors.Cause(e)
if m, ok = originErr.(*mysql.SQLError); !ok {
m = mysql.NewErrf(mysql.ErrUnknown, e.Error())
}
data := make([]byte, 4, 16+len(m.Message))

View File

@ -36,7 +36,6 @@ package server
import (
"encoding/binary"
"fmt"
"math"
"strconv"
@ -113,7 +112,7 @@ func (cc *clientConn) handleStmtExecute(data []byte) (err error) {
stmt := cc.ctx.GetStatement(int(stmtID))
if stmt == nil {
return mysql.NewDefaultError(mysql.ErUnknownStmtHandler,
return mysql.NewErr(mysql.ErrUnknownStmtHandler,
strconv.FormatUint(uint64(stmtID), 10), "stmt_execute")
}
@ -121,7 +120,7 @@ func (cc *clientConn) handleStmtExecute(data []byte) (err error) {
pos++
//now we only support CURSOR_TYPE_NO_CURSOR flag
if flag != 0 {
return mysql.NewError(mysql.ErUnknownError, fmt.Sprintf("unsupported flag %d", flag))
return mysql.NewErrf(mysql.ErrUnknown, "unsupported flag %d", flag)
}
//skip iteration-count, always 1
@ -324,7 +323,7 @@ func (cc *clientConn) handleStmtSendLongData(data []byte) (err error) {
stmt := cc.ctx.GetStatement(stmtID)
if stmt == nil {
return mysql.NewDefaultError(mysql.ErUnknownStmtHandler,
return mysql.NewErr(mysql.ErrUnknownStmtHandler,
strconv.Itoa(stmtID), "stmt_send_longdata")
}
@ -340,7 +339,7 @@ func (cc *clientConn) handleStmtReset(data []byte) (err error) {
stmtID := int(binary.LittleEndian.Uint32(data[0:4]))
stmt := cc.ctx.GetStatement(stmtID)
if stmt == nil {
return mysql.NewDefaultError(mysql.ErUnknownStmtHandler,
return mysql.NewErr(mysql.ErrUnknownStmtHandler,
strconv.Itoa(stmtID), "stmt_reset")
}
stmt.Reset()

View File

@ -21,6 +21,7 @@ import (
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/rset"
"github.com/pingcap/tidb/util/errors2"
"github.com/pingcap/tidb/util/types"
)
// TiDBDriver implements IDriver.
@ -78,7 +79,7 @@ func (ts *TiDBStatement) Execute(args ...interface{}) (rs ResultSet, err error)
// AppendParam implements IStatement AppendParam method.
func (ts *TiDBStatement) AppendParam(paramID int, data []byte) error {
if paramID >= len(ts.boundParams) {
return mysql.NewDefaultError(mysql.ErWrongArguments, "stmt_send_longdata")
return mysql.NewErr(mysql.ErrWrongArguments, "stmt_send_longdata")
}
ts.boundParams[paramID] = append(ts.boundParams[paramID], data...)
return nil
@ -272,8 +273,16 @@ func convertColumnInfo(fld *field.ResultField) (ci *ColumnInfo) {
ci.Schema = fld.DBName
ci.Flag = uint16(fld.Flag)
ci.Charset = uint16(mysql.CharsetIDs[fld.Charset])
ci.ColumnLength = uint32(fld.Flen)
ci.Decimal = uint8(fld.Decimal)
if fld.Flen == types.UnspecifiedLength {
ci.ColumnLength = 0
} else {
ci.ColumnLength = uint32(fld.Flen)
}
if fld.Decimal == types.UnspecifiedLength {
ci.Decimal = 0
} else {
ci.Decimal = uint8(fld.Decimal)
}
ci.Type = uint8(fld.Tp)
return
}

20
tidb.go
View File

@ -26,11 +26,13 @@ import (
"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/field"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/optimizer"
"github.com/pingcap/tidb/parser"
"github.com/pingcap/tidb/rset"
"github.com/pingcap/tidb/sessionctx/variable"
@ -92,8 +94,20 @@ func Compile(src string) ([]stmt.Statement, error) {
log.Warnf("compiling %s, error: %v", src, l.Errors()[0])
return nil, errors.Trace(l.Errors()[0])
}
return l.Stmts(), nil
rawStmt := l.Stmts()
stmts := make([]stmt.Statement, len(rawStmt))
for i, v := range rawStmt {
if node, ok := v.(ast.Node); ok {
stm, err := optimizer.Compile(node)
if err != nil {
return nil, errors.Trace(err)
}
stmts[i] = stm
} else {
stmts[i] = v.(stmt.Statement)
}
}
return stmts, nil
}
// CompilePrepare compiles prepared statement, allows placeholder as expr.
@ -112,7 +126,7 @@ func CompilePrepare(src string) (stmt.Statement, []*expression.ParamMarker, erro
return nil, nil, nil
}
sm := sms[0]
return sm, l.ParamList, nil
return sm.(stmt.Statement), l.ParamList, nil
}
func prepareStmt(ctx context.Context, sqlText string) (stmtID uint32, paramCount int, fields []*field.ResultField, err error) {

View File

@ -735,6 +735,21 @@ func (s *testSessionSuite) TestIndex(c *C) {
c.Assert(err, IsNil)
c.Assert(rows, HasLen, 1)
match(c, rows[0], 1)
mustExecSQL(c, se, "drop table if exists t1, t2")
mustExecSQL(c, se, `
create table t1 (c1 int, primary key(c1));
create table t2 (c2 int, primary key(c2));
insert into t1 values (1), (2);
insert into t2 values (2);`)
r = mustExecSQL(c, se, "select * from t1 left join t2 on t1.c1 = t2.c2 order by t1.c1")
rows, err = r.Rows(-1, 0)
matches(c, rows, [][]interface{}{{1, nil}, {2, 2}})
r = mustExecSQL(c, se, "select * from t1 left join t2 on t1.c1 = t2.c2 where t2.c2 < 10")
rows, err = r.Rows(-1, 0)
matches(c, rows, [][]interface{}{{2, 2}})
}
func (s *testSessionSuite) TestMySQLTypes(c *C) {
@ -792,6 +807,11 @@ func (s *testSessionSuite) TestSelect(c *C) {
c.Assert(err, IsNil)
match(c, row, 1, 2)
r = mustExecSQL(c, se, "select 1, 2 from dual where not exists (select * from t where c1=2)")
row, err = r.FirstRow()
c.Assert(err, IsNil)
match(c, row, 1, 2)
r = mustExecSQL(c, se, "select 1, 2")
row, err = r.FirstRow()
c.Assert(err, IsNil)
@ -848,6 +868,20 @@ func (s *testSessionSuite) TestSelect(c *C) {
row, err = r.FirstRow()
c.Assert(err, IsNil)
c.Assert(row, IsNil)
mustExecSQL(c, se, "drop table if exists t1, t2, t3")
mustExecSQL(c, se, `
create table t1 (c1 int);
create table t2 (c2 int);
create table t3 (c3 int);
insert into t1 values (1), (2);
insert into t2 values (2);
insert into t3 values (3);`)
r = mustExecSQL(c, se, "select * from t1 left join t2 on t1.c1 = t2.c2 left join t3 on t1.c1 = t3.c3 order by t1.c1")
rows, err = r.Rows(-1, 0)
c.Assert(err, IsNil)
matches(c, rows, [][]interface{}{{1, nil, nil}, {2, 2, nil}})
}
func (s *testSessionSuite) TestSubQuery(c *C) {
@ -870,6 +904,8 @@ func (s *testSessionSuite) TestSubQuery(c *C) {
c.Assert(rows, HasLen, 2)
match(c, rows[0], 0)
match(c, rows[1], 2)
mustExecMatch(c, se, "select a.c1, a.c2 from (select c1 as c1, c1 as c2 from t1) as a", [][]interface{}{{1, 1}, {2, 2}})
}
func (s *testSessionSuite) TestShow(c *C) {
@ -888,7 +924,7 @@ func (s *testSessionSuite) TestShow(c *C) {
rows, err := r.Rows(-1, 0)
c.Assert(err, IsNil)
c.Assert(rows, HasLen, 1)
match(c, rows[0], "c", "int", "YES", "", nil, "")
match(c, rows[0], "c", "int(11)", "YES", "", nil, "")
r = mustExecSQL(c, se, "show collation where Charset = 'utf8' and Collation = 'utf8_bin'")
row, err = r.FirstRow()
@ -1067,6 +1103,86 @@ func (s *testSessionSuite) TestWhereLike(c *C) {
c.Assert(rows, HasLen, 6)
}
func (s *testSessionSuite) TestDefaultFlenBug(c *C) {
// If set unspecified column flen to 0, it will cause bug in union.
// This test is used to prevent the bug reappear.
store := newStore(c, s.dbName)
se := newSession(c, store, s.dbName)
mustExecSQL(c, se, "create table t1 (c double);")
mustExecSQL(c, se, "create table t2 (c double);")
mustExecSQL(c, se, "insert into t1 value (73);")
mustExecSQL(c, se, "insert into t2 value (930);")
// The data in the second src will be casted as the type of the first src.
// If use flen=0, it will be truncated.
r := mustExecSQL(c, se, "select c from t1 union select c from t2;")
rows, err := r.Rows(-1, 0)
c.Assert(err, IsNil)
c.Assert(rows, HasLen, 2)
c.Assert(rows[1][0], Equals, float64(930))
}
func (s *testSessionSuite) TestExecRestrictedSQL(c *C) {
store := newStore(c, s.dbName)
se := newSession(c, store, s.dbName).(*session)
r, err := se.ExecRestrictedSQL(se, "select 1;")
c.Assert(r, NotNil)
c.Assert(err, IsNil)
_, err = se.ExecRestrictedSQL(se, "select 1; select 2;")
c.Assert(err, NotNil)
_, err = se.ExecRestrictedSQL(se, "")
c.Assert(err, NotNil)
}
func (s *testSessionSuite) TestGroupBy(c *C) {
store := newStore(c, s.dbName)
se := newSession(c, store, s.dbName)
mustExecSQL(c, se, "drop table if exists t")
mustExecSQL(c, se, "create table t (c1 int, c2 int)")
mustExecSQL(c, se, "insert into t values (1,1), (2,2), (1,2), (1,3)")
mustExecMatch(c, se, "select c1 as c2, c2 from t group by c2 + 1", [][]interface{}{{1, 1}, {2, 2}, {1, 3}})
mustExecMatch(c, se, "select c1 as c2, count(c1) from t group by c2", [][]interface{}{{1, 1}, {2, 2}, {1, 1}})
mustExecMatch(c, se, "select t.c1, c1 from t group by c1", [][]interface{}{{1, 1}, {2, 2}})
mustExecMatch(c, se, "select t.c1 as a, c1 as a from t group by a", [][]interface{}{{1, 1}, {2, 2}})
mustExecFailed(c, se, "select c1 as a, c2 as a from t group by a")
mustExecFailed(c, se, "select c1 as c2, c2 from t group by c2")
mustExecFailed(c, se, "select sum(c1) as a from t group by a")
mustExecFailed(c, se, "select sum(c1) as a from t group by a + 1")
}
func (s *testSessionSuite) TestOrderBy(c *C) {
store := newStore(c, s.dbName)
se := newSession(c, store, s.dbName)
mustExecSQL(c, se, "drop table if exists t")
mustExecSQL(c, se, "create table t (c1 int, c2 int)")
mustExecSQL(c, se, "insert into t values (1,2), (2, 1)")
// Fix issue https://github.com/pingcap/tidb/issues/337
mustExecMatch(c, se, "select c1 as a, c1 as b from t order by c1", [][]interface{}{{1, 1}, {2, 2}})
mustExecMatch(c, se, "select c1 as a, t.c1 as a from t order by a desc", [][]interface{}{{2, 2}, {1, 1}})
mustExecMatch(c, se, "select c1 as c2 from t order by c2", [][]interface{}{{1}, {2}})
mustExecMatch(c, se, "select sum(c1) from t order by sum(c1)", [][]interface{}{{3}})
// TODO: now this test result is not same as MySQL, we will update it later.
mustExecMatch(c, se, "select c1 as c2 from t order by c2 + 1", [][]interface{}{{1}, {2}})
mustExecFailed(c, se, "select c1 as a, c2 as a from t order by a")
}
func (s *testSessionSuite) TestHaving(c *C) {
store := newStore(c, s.dbName)
se := newSession(c, store, s.dbName)
mustExecSQL(c, se, "drop table if exists t")
mustExecSQL(c, se, "create table t (c1 int, c2 int)")
mustExecSQL(c, se, "insert into t values (1,2), (2, 1)")
mustExecMatch(c, se, "select sum(c1) from t group by c1 having sum(c1)", [][]interface{}{{1}, {2}})
mustExecMatch(c, se, "select sum(c1) - 1 from t group by c1 having sum(c1) - 1", [][]interface{}{{1}})
}
func newSession(c *C, store kv.Storage, dbName string) Session {
se, err := CreateSession(store)
c.Assert(err, IsNil)
@ -1131,3 +1247,22 @@ func match(c *C, row []interface{}, expected ...interface{}) {
c.Assert(got, Equals, need)
}
}
func matches(c *C, rows [][]interface{}, expected [][]interface{}) {
c.Assert(len(rows), Equals, len(expected))
for i := 0; i < len(rows); i++ {
match(c, rows[i], expected[i]...)
}
}
func mustExecMatch(c *C, se Session, sql string, expected [][]interface{}) {
r := mustExecSQL(c, se, sql)
rows, err := r.Rows(-1, 0)
c.Assert(err, IsNil)
matches(c, rows, expected)
}
func mustExecFailed(c *C, se Session, sql string, args ...interface{}) {
_, err := exec(c, se, sql, args...)
c.Assert(err, NotNil)
}

View File

@ -0,0 +1,44 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlexec
import (
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/rset"
)
// RestrictedSQLExecutorKeyType is a dummy type to avoid naming collision in session.
type RestrictedSQLExecutorKeyType struct{}
// String implements Stringer interface.
func (k *RestrictedSQLExecutorKeyType) String() string {
return "restricted_sql_executor"
}
// RestrictedSQLExecutor is an interface provides executing restricted sql statement.
// Why we need this interface?
// When we execute some management statements, we need to operate system tables.
// For example when executing create user statement, we need to check if the user already
// exists in the mysql.User table and insert a new row if not exists. In this case, we need
// a convenience way to manipulate system tables. The most simple way is executing sql statement.
// In order to execute sql statement in stmts package, we add this interface to solve dependence problem.
// And in the same time, we do not want this interface becomes a general way to run sql statement.
// We hope this could be used with some restrictions such as only allowing system tables as target,
// do not allowing recursion call.
// For more infomation please refer to the comments in session.ExecRestrictedSQL().
type RestrictedSQLExecutor interface {
// ExecRestrictedSQL run sql statement in ctx with some restriction.
// This is implemented in session.go.
ExecRestrictedSQL(ctx context.Context, sql string) (rset.Recordset, error)
}

View File

@ -447,11 +447,11 @@ func (s *testTypeConvertSuite) TestStrToNum(c *C) {
}
func (s *testTypeConvertSuite) TestFieldTypeToStr(c *C) {
v := FieldTypeToStr(mysql.TypeDecimal, "not binary")
v := TypeToStr(mysql.TypeDecimal, "not binary")
c.Assert(v, Equals, type2Str[mysql.TypeDecimal])
v = FieldTypeToStr(mysql.TypeBlob, charset.CharsetBin)
v = TypeToStr(mysql.TypeBlob, charset.CharsetBin)
c.Assert(v, Equals, "blob")
v = FieldTypeToStr(mysql.TypeString, charset.CharsetBin)
v = TypeToStr(mysql.TypeString, charset.CharsetBin)
c.Assert(v, Equals, "binary")
}

View File

@ -23,7 +23,6 @@ import (
"strings"
"github.com/juju/errors"
"github.com/ngaut/log"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/charset"
@ -85,84 +84,13 @@ func TypeStr(tp byte) (r string) {
return type2Str[tp]
}
// TypeToStr converts tp to a string with an extra binary.
func TypeToStr(tp byte, binary bool) string {
switch tp {
case mysql.TypeBlob:
if binary {
return "text"
}
return "blob"
case mysql.TypeLongBlob:
if binary {
return "longtext"
}
return "longblob"
case mysql.TypeTinyBlob:
if binary {
return "tinytext"
}
return "tinyblob"
case mysql.TypeMediumBlob:
if binary {
return "mediumtext"
}
return "mediumblob"
case mysql.TypeVarchar:
if binary {
return "varbinary"
}
return "varchar"
case mysql.TypeString:
if binary {
return "binary"
}
return "char"
case mysql.TypeTiny:
return "tinyint"
case mysql.TypeShort:
return "smallint"
case mysql.TypeInt24:
return "mediumint"
case mysql.TypeLong:
return "int"
case mysql.TypeLonglong:
return "bigint"
case mysql.TypeFloat:
return "float"
case mysql.TypeDouble:
return "double"
case mysql.TypeDecimal, mysql.TypeNewDecimal:
return "decimal"
case mysql.TypeYear:
return "year"
case mysql.TypeDuration:
return "time"
case mysql.TypeDatetime:
return "datetime"
case mysql.TypeDate:
return "date"
case mysql.TypeTimestamp:
return "timestamp"
case mysql.TypeBit:
return "bit"
case mysql.TypeEnum:
return "enum"
case mysql.TypeSet:
return "set"
default:
log.Errorf("unkown type %d, binary %v", tp, binary)
}
return ""
}
// FieldTypeToStr converts a field to a string.
// TypeToStr converts a field to a string.
// It is used for converting Text to Blob,
// or converting Char to Binary.
// Args:
// tp: type enum
// cs: charset
func FieldTypeToStr(tp byte, cs string) (r string) {
func TypeToStr(tp byte, cs string) (r string) {
ts := type2Str[tp]
if cs != charset.CharsetBin {
return ts

View File

@ -59,8 +59,8 @@ func testTypeStr(c *C, tp byte, expect string) {
c.Assert(v, Equals, expect)
}
func testTypeToStr(c *C, tp byte, binary bool, expect string) {
v := TypeToStr(tp, binary)
func testTypeToStr(c *C, tp byte, charset string, expect string) {
v := TypeToStr(tp, charset)
c.Assert(v, Equals, expect)
}
@ -68,36 +68,36 @@ func (s *testTypeEtcSuite) TestTypeToStr(c *C) {
testTypeStr(c, mysql.TypeYear, "year")
testTypeStr(c, 0xdd, "")
testTypeToStr(c, mysql.TypeBlob, true, "text")
testTypeToStr(c, mysql.TypeLongBlob, true, "longtext")
testTypeToStr(c, mysql.TypeTinyBlob, true, "tinytext")
testTypeToStr(c, mysql.TypeMediumBlob, true, "mediumtext")
testTypeToStr(c, mysql.TypeVarchar, true, "varbinary")
testTypeToStr(c, mysql.TypeString, true, "binary")
testTypeToStr(c, mysql.TypeTiny, true, "tinyint")
testTypeToStr(c, mysql.TypeBlob, false, "blob")
testTypeToStr(c, mysql.TypeLongBlob, false, "longblob")
testTypeToStr(c, mysql.TypeTinyBlob, false, "tinyblob")
testTypeToStr(c, mysql.TypeMediumBlob, false, "mediumblob")
testTypeToStr(c, mysql.TypeVarchar, false, "varchar")
testTypeToStr(c, mysql.TypeString, false, "char")
testTypeToStr(c, mysql.TypeShort, true, "smallint")
testTypeToStr(c, mysql.TypeInt24, true, "mediumint")
testTypeToStr(c, mysql.TypeLong, true, "int")
testTypeToStr(c, mysql.TypeLonglong, true, "bigint")
testTypeToStr(c, mysql.TypeFloat, true, "float")
testTypeToStr(c, mysql.TypeDouble, true, "double")
testTypeToStr(c, mysql.TypeYear, true, "year")
testTypeToStr(c, mysql.TypeDuration, true, "time")
testTypeToStr(c, mysql.TypeDatetime, true, "datetime")
testTypeToStr(c, mysql.TypeDate, true, "date")
testTypeToStr(c, mysql.TypeTimestamp, true, "timestamp")
testTypeToStr(c, mysql.TypeNewDecimal, true, "decimal")
testTypeToStr(c, mysql.TypeDecimal, true, "decimal")
testTypeToStr(c, 0xdd, true, "")
testTypeToStr(c, mysql.TypeBit, true, "bit")
testTypeToStr(c, mysql.TypeEnum, true, "enum")
testTypeToStr(c, mysql.TypeSet, true, "set")
testTypeToStr(c, mysql.TypeBlob, "utf8", "text")
testTypeToStr(c, mysql.TypeLongBlob, "utf8", "longtext")
testTypeToStr(c, mysql.TypeTinyBlob, "utf8", "tinytext")
testTypeToStr(c, mysql.TypeMediumBlob, "utf8", "mediumtext")
testTypeToStr(c, mysql.TypeVarchar, "binary", "varbinary")
testTypeToStr(c, mysql.TypeString, "binary", "binary")
testTypeToStr(c, mysql.TypeTiny, "binary", "tinyint")
testTypeToStr(c, mysql.TypeBlob, "binary", "blob")
testTypeToStr(c, mysql.TypeLongBlob, "binary", "longblob")
testTypeToStr(c, mysql.TypeTinyBlob, "binary", "tinyblob")
testTypeToStr(c, mysql.TypeMediumBlob, "binary", "mediumblob")
testTypeToStr(c, mysql.TypeVarchar, "utf8", "varchar")
testTypeToStr(c, mysql.TypeString, "utf8", "char")
testTypeToStr(c, mysql.TypeShort, "binary", "smallint")
testTypeToStr(c, mysql.TypeInt24, "binary", "mediumint")
testTypeToStr(c, mysql.TypeLong, "binary", "int")
testTypeToStr(c, mysql.TypeLonglong, "binary", "bigint")
testTypeToStr(c, mysql.TypeFloat, "binary", "float")
testTypeToStr(c, mysql.TypeDouble, "binary", "double")
testTypeToStr(c, mysql.TypeYear, "binary", "year")
testTypeToStr(c, mysql.TypeDuration, "binary", "time")
testTypeToStr(c, mysql.TypeDatetime, "binary", "datetime")
testTypeToStr(c, mysql.TypeDate, "binary", "date")
testTypeToStr(c, mysql.TypeTimestamp, "binary", "timestamp")
testTypeToStr(c, mysql.TypeNewDecimal, "binary", "decimal")
testTypeToStr(c, mysql.TypeDecimal, "binary", "decimal")
testTypeToStr(c, 0xdd, "binary", "")
testTypeToStr(c, mysql.TypeBit, "binary", "bit")
testTypeToStr(c, mysql.TypeEnum, "binary", "enum")
testTypeToStr(c, mysql.TypeSet, "binary", "set")
}
func (s *testTypeEtcSuite) TestEOFAsNil(c *C) {

View File

@ -52,45 +52,62 @@ func NewFieldType(tp byte) *FieldType {
}
}
// String joins the information of FieldType and
// returns a string.
func (ft *FieldType) String() string {
ts := FieldTypeToStr(ft.Tp, ft.Charset)
ans := []string{ts}
// CompactStr only considers Tp/CharsetBin/Flen/Deimal.
// This is used for showing column type in infoschema.
func (ft *FieldType) CompactStr() string {
ts := TypeToStr(ft.Tp, ft.Charset)
suffix := ""
switch ft.Tp {
case mysql.TypeEnum, mysql.TypeSet:
// Format is ENUM ('e1', 'e2') or SET ('e1', 'e2')
ans = append(ans, fmt.Sprintf("('%s')", strings.Join(ft.Elems, "','")))
es := make([]string, 0, len(ft.Elems))
for _, e := range ft.Elems {
e = strings.Replace(e, "'", "''", -1)
es = append(es, e)
}
suffix = fmt.Sprintf("('%s')", strings.Join(es, "','"))
case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate:
if ft.Decimal != UnspecifiedLength && ft.Decimal != 0 {
suffix = fmt.Sprintf("(%d)", ft.Decimal)
}
default:
if ft.Flen != UnspecifiedLength {
if ft.Decimal == UnspecifiedLength {
ans = append(ans, fmt.Sprintf("(%d)", ft.Flen))
if ft.Tp != mysql.TypeFloat && ft.Tp != mysql.TypeDouble {
suffix = fmt.Sprintf("(%d)", ft.Flen)
}
} else {
ans = append(ans, fmt.Sprintf("(%d, %d)", ft.Flen, ft.Decimal))
suffix = fmt.Sprintf("(%d,%d)", ft.Flen, ft.Decimal)
}
} else if ft.Decimal != UnspecifiedLength {
ans = append(ans, fmt.Sprintf("(%d)", ft.Decimal))
suffix = fmt.Sprintf("(%d)", ft.Decimal)
}
}
return ts + suffix
}
// String joins the information of FieldType and
// returns a string.
func (ft *FieldType) String() string {
strs := []string{ft.CompactStr()}
if mysql.HasUnsignedFlag(ft.Flag) {
ans = append(ans, "UNSIGNED")
strs = append(strs, "UNSIGNED")
}
if mysql.HasZerofillFlag(ft.Flag) {
ans = append(ans, "ZEROFILL")
strs = append(strs, "ZEROFILL")
}
if mysql.HasBinaryFlag(ft.Flag) {
ans = append(ans, "BINARY")
strs = append(strs, "BINARY")
}
if IsTypeChar(ft.Tp) || IsTypeBlob(ft.Tp) {
if ft.Charset != "" && ft.Charset != charset.CharsetBin {
ans = append(ans, fmt.Sprintf("CHARACTER SET %s", ft.Charset))
strs = append(strs, fmt.Sprintf("CHARACTER SET %s", ft.Charset))
}
if ft.Collate != "" && ft.Collate != charset.CharsetBin {
ans = append(ans, fmt.Sprintf("COLLATE %s", ft.Collate))
strs = append(strs, fmt.Sprintf("COLLATE %s", ft.Collate))
}
}
return strings.Join(ans, " ")
return strings.Join(strs, " ")
}

View File

@ -28,33 +28,80 @@ func (s *testFieldTypeSuite) TestFieldType(c *C) {
c.Assert(ft.Flen, Equals, UnspecifiedLength)
c.Assert(ft.Decimal, Equals, UnspecifiedLength)
ft.Decimal = 5
c.Assert(ft.String(), Equals, "time (5)")
c.Assert(ft.String(), Equals, "time(5)")
ft.Tp = mysql.TypeLong
ft.Flag |= mysql.UnsignedFlag | mysql.ZerofillFlag
c.Assert(ft.String(), Equals, "int (5) UNSIGNED ZEROFILL")
c.Assert(ft.String(), Equals, "int(5) UNSIGNED ZEROFILL")
ft = NewFieldType(mysql.TypeFloat)
ft.Flen = 10
ft.Decimal = 3
c.Assert(ft.String(), Equals, "float (10, 3)")
c.Assert(ft.String(), Equals, "float(10,3)")
ft = NewFieldType(mysql.TypeFloat)
ft.Flen = 10
ft.Decimal = -1
c.Assert(ft.String(), Equals, "float")
ft = NewFieldType(mysql.TypeDouble)
ft.Flen = 10
ft.Decimal = 3
c.Assert(ft.String(), Equals, "double(10,3)")
ft = NewFieldType(mysql.TypeDouble)
ft.Flen = 10
ft.Decimal = -1
c.Assert(ft.String(), Equals, "double")
ft = NewFieldType(mysql.TypeBlob)
ft.Flen = 10
ft.Charset = "UTF8"
ft.Collate = "UTF8_UNICODE_GI"
c.Assert(ft.String(), Equals, "text (10) CHARACTER SET UTF8 COLLATE UTF8_UNICODE_GI")
c.Assert(ft.String(), Equals, "text(10) CHARACTER SET UTF8 COLLATE UTF8_UNICODE_GI")
ft = NewFieldType(mysql.TypeVarchar)
ft.Flen = 10
ft.Flag |= mysql.BinaryFlag
c.Assert(ft.String(), Equals, "varchar (10) BINARY")
c.Assert(ft.String(), Equals, "varchar(10) BINARY")
ft = NewFieldType(mysql.TypeEnum)
ft.Elems = []string{"a", "b"}
c.Assert(ft.String(), Equals, "enum ('a','b')")
c.Assert(ft.String(), Equals, "enum('a','b')")
ft = NewFieldType(mysql.TypeEnum)
ft.Elems = []string{"'a'", "'b'"}
c.Assert(ft.String(), Equals, "enum('''a''','''b''')")
ft = NewFieldType(mysql.TypeSet)
ft.Elems = []string{"a", "b"}
c.Assert(ft.String(), Equals, "set ('a','b')")
c.Assert(ft.String(), Equals, "set('a','b')")
ft = NewFieldType(mysql.TypeSet)
ft.Elems = []string{"'a'", "'b'"}
c.Assert(ft.String(), Equals, "set('''a''','''b''')")
ft = NewFieldType(mysql.TypeTimestamp)
ft.Flen = 8
ft.Decimal = 2
c.Assert(ft.String(), Equals, "timestamp(2)")
ft = NewFieldType(mysql.TypeTimestamp)
ft.Flen = 8
ft.Decimal = 0
c.Assert(ft.String(), Equals, "timestamp")
ft = NewFieldType(mysql.TypeDatetime)
ft.Flen = 8
ft.Decimal = 2
c.Assert(ft.String(), Equals, "datetime(2)")
ft = NewFieldType(mysql.TypeDatetime)
ft.Flen = 8
ft.Decimal = 0
c.Assert(ft.String(), Equals, "datetime")
ft = NewFieldType(mysql.TypeDate)
ft.Flen = 8
ft.Decimal = 2
c.Assert(ft.String(), Equals, "date(2)")
ft = NewFieldType(mysql.TypeDate)
ft.Flen = 8
ft.Decimal = 0
c.Assert(ft.String(), Equals, "date")
}