*: refact insert logic (#2252)

This commit is contained in:
Han Fei
2016-12-20 17:38:23 +08:00
committed by GitHub
parent f4e2bcacb2
commit a89fa8a8d3
11 changed files with 211 additions and 111 deletions

View File

@ -69,6 +69,7 @@ const (
RowFunc = "row"
SetVar = "setvar"
GetVar = "getvar"
Values = "values"
// common functions
Coalesce = "coalesce"

View File

@ -19,6 +19,7 @@ import (
"time"
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
@ -558,3 +559,21 @@ func builtinReleaseLock(args []types.Datum, _ context.Context) (d types.Datum, e
d.SetInt64(1)
return d, nil
}
// BuildinValuesFactory generates values builtin function.
func BuildinValuesFactory(v *ast.ValuesExpr) BuiltinFunc {
return func(_ []types.Datum, ctx context.Context) (d types.Datum, err error) {
values := ctx.GetSessionVars().CurrInsertValues
if values == nil {
err = errors.New("Session current insert values is nil")
return
}
row := values.([]types.Datum)
offset := v.Column.Refer.Column.Offset
if len(row) > offset {
return row[offset], nil
}
err = errors.Errorf("Session current insert values len %d and column's offset %v don't match", len(row), offset)
return
}
}

View File

@ -52,23 +52,6 @@ func Eval(ctx context.Context, expr ast.ExprNode) (d types.Datum, err error) {
return *expr.GetDatum(), nil
}
// EvalBool evalueates an expression to a boolean value.
func EvalBool(ctx context.Context, expr ast.ExprNode) (bool, error) {
val, err := Eval(ctx, expr)
if err != nil {
return false, errors.Trace(err)
}
if val.IsNull() {
return false, nil
}
i, err := val.ToBool(ctx.GetSessionVars().StmtCtx)
if err != nil {
return false, errors.Trace(err)
}
return i != 0, nil
}
func boolToInt64(v bool) int64 {
if v {
return int64(1)
@ -589,13 +572,13 @@ func (e *Evaluator) values(v *ast.ValuesExpr) bool {
}
row := values.([]types.Datum)
off := v.Column.Refer.Column.Offset
if len(row) > off {
v.SetDatum(row[off])
offset := v.Column.Refer.Column.Offset
if len(row) > offset {
v.SetDatum(row[offset])
return true
}
e.err = errors.Errorf("Session current insert values len %d and column's offset %v don't match", len(row), off)
e.err = errors.Errorf("Session current insert values len %d and column's offset %v don't match", len(row), offset)
return false
}

View File

@ -241,24 +241,7 @@ func (b *executorBuilder) buildInsert(v *plan.Insert) Executor {
if len(v.GetChildren()) > 0 {
ivs.SelectExec = b.build(v.GetChildByIndex(0))
}
// Get Table
ts, ok := v.Table.TableRefs.Left.(*ast.TableSource)
if !ok {
b.err = errors.New("Can not get table")
return nil
}
tn, ok := ts.Source.(*ast.TableName)
if !ok {
b.err = errors.New("Can not get table")
return nil
}
tableInfo := tn.TableInfo
tbl, ok := b.is.TableByID(tableInfo.ID)
if !ok {
b.err = errors.Errorf("Can not get table %d", tableInfo.ID)
return nil
}
ivs.Table = tbl
ivs.Table = v.Table
if v.IsReplace {
return b.buildReplace(ivs)
}
@ -268,8 +251,6 @@ func (b *executorBuilder) buildInsert(v *plan.Insert) Executor {
Priority: v.Priority,
Ignore: v.Ignore,
}
// fields is used to evaluate values expr.
insert.fields = ts.GetResultFields()
return insert
}

View File

@ -21,7 +21,6 @@ import (
"github.com/ngaut/log"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/evaluator"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/mysql"
@ -546,8 +545,8 @@ type InsertValues struct {
Table table.Table
Columns []*ast.ColumnName
Lists [][]ast.ExprNode
Setlist []*ast.Assignment
Lists [][]expression.Expression
Setlist []*expression.Assignment
IsPrepare bool
}
@ -555,8 +554,7 @@ type InsertValues struct {
type InsertExec struct {
*InsertValues
OnDuplicate []*ast.Assignment
fields []*ast.ResultField
OnDuplicate []*expression.Assignment
Priority int
Ignore bool
@ -654,7 +652,7 @@ func (e *InsertValues) getColumns(tableCols []*table.Column) ([]*table.Column, e
// Process `set` type column.
columns := make([]string, 0, len(e.Setlist))
for _, v := range e.Setlist {
columns = append(columns, v.Column.Name.O)
columns = append(columns, v.Col.ColName.O)
}
cols, err = table.FindCols(tableCols, columns)
@ -696,7 +694,7 @@ func (e *InsertValues) fillValueList() error {
if len(e.Lists) > 0 {
return errors.Errorf("INSERT INTO %s: set type should not use values", e.Table)
}
var l []ast.ExprNode
l := make([]expression.Expression, 0, len(e.Setlist))
for _, v := range e.Setlist {
l = append(l, v.Expr)
}
@ -706,6 +704,7 @@ func (e *InsertValues) fillValueList() error {
}
func (e *InsertValues) checkValueCount(insertValueCount, valueCount, num int, cols []*table.Column) error {
// TODO: This check should be done in plan builder.
if insertValueCount != valueCount {
// "insert into t values (), ()" is valid.
// "insert into t values (), (1)" is not valid.
@ -723,30 +722,12 @@ func (e *InsertValues) checkValueCount(insertValueCount, valueCount, num int, co
return nil
}
func (e *InsertValues) getColumnDefaultValues(cols []*table.Column) (map[string]types.Datum, error) {
defaultValMap := map[string]types.Datum{}
for _, col := range cols {
if value, ok, err := table.GetColDefaultValue(e.ctx, col.ToInfo()); ok {
if err != nil {
return nil, errors.Trace(err)
}
defaultValMap[col.Name.L] = value
}
}
return defaultValMap, nil
}
func (e *InsertValues) getRows(cols []*table.Column) (rows [][]types.Datum, err error) {
// process `insert|replace ... set x=y...`
if err = e.fillValueList(); err != nil {
return nil, errors.Trace(err)
}
defaultVals, err := e.getColumnDefaultValues(e.Table.Cols())
if err != nil {
return nil, errors.Trace(err)
}
rows = make([][]types.Datum, len(e.Lists))
length := len(e.Lists[0])
for i, list := range e.Lists {
@ -754,7 +735,7 @@ func (e *InsertValues) getRows(cols []*table.Column) (rows [][]types.Datum, err
return nil, errors.Trace(err)
}
e.currRow = i
rows[i], err = e.getRow(cols, list, defaultVals)
rows[i], err = e.getRow(cols, list)
if err != nil {
return nil, errors.Trace(err)
}
@ -762,28 +743,13 @@ func (e *InsertValues) getRows(cols []*table.Column) (rows [][]types.Datum, err
return
}
func (e *InsertValues) getRow(cols []*table.Column, list []ast.ExprNode, defaultVals map[string]types.Datum) ([]types.Datum, error) {
func (e *InsertValues) getRow(cols []*table.Column, list []expression.Expression) ([]types.Datum, error) {
vals := make([]types.Datum, len(list))
var err error
for i, expr := range list {
if d, ok := expr.(*ast.DefaultExpr); ok {
cn := d.Name
if cn == nil {
vals[i] = defaultVals[cols[i].Name.L]
continue
}
var found bool
vals[i], found = defaultVals[cn.Name.L]
if !found {
return nil, errors.Errorf("default column not found - %s", cn.Name.O)
}
} else {
var val types.Datum
val, err = evaluator.Eval(e.ctx, expr)
vals[i] = val
if err != nil {
return nil, errors.Trace(err)
}
val, err := expr.Eval(nil, e.ctx)
vals[i] = val
if err != nil {
return nil, errors.Trace(err)
}
}
return e.fillRowData(cols, vals, false)
@ -914,17 +880,12 @@ func (e *InsertValues) initDefaultValues(row []types.Datum, marked map[int]struc
// onDuplicateUpdate updates the duplicate row.
// TODO: Report rows affected and last insert id.
func (e *InsertExec) onDuplicateUpdate(row []types.Datum, h int64, cols map[int]*ast.Assignment) error {
func (e *InsertExec) onDuplicateUpdate(row []types.Datum, h int64, cols map[int]*expression.Assignment) error {
data, err := e.Table.Row(e.ctx, h)
if err != nil {
return errors.Trace(err)
}
// for evaluating ColumnNameExpr
for i, rf := range e.fields {
rf.Expr.SetValue(data[i].GetValue())
}
// for evaluating ValuesExpr
// See http://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
e.ctx.GetSessionVars().CurrInsertValues = row
// evaluate assignment
@ -935,7 +896,7 @@ func (e *InsertExec) onDuplicateUpdate(row []types.Datum, h int64, cols map[int]
newData[i] = c
continue
}
val, err1 := evaluator.Eval(e.ctx, asgn.Expr)
val, err1 := asgn.Expr.Eval(data, e.ctx)
if err1 != nil {
return errors.Trace(err1)
}
@ -968,12 +929,12 @@ func findColumnByName(t table.Table, tableName, colName string) (*table.Column,
return c, nil
}
func getOnDuplicateUpdateColumns(assignList []*ast.Assignment, t table.Table) (map[int]*ast.Assignment, error) {
m := make(map[int]*ast.Assignment, len(assignList))
func getOnDuplicateUpdateColumns(assignList []*expression.Assignment, t table.Table) (map[int]*expression.Assignment, error) {
m := make(map[int]*expression.Assignment, len(assignList))
for _, v := range assignList {
col := v.Column
c, err := findColumnByName(t, col.Table.L, col.Name.L)
col := v.Col
c, err := findColumnByName(t, col.TblName.L, col.ColName.L)
if err != nil {
return nil, errors.Trace(err)
}

View File

@ -21,6 +21,7 @@ import (
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/types"
@ -261,3 +262,18 @@ func ResultFieldsToSchema(fields []*ast.ResultField) Schema {
}
return schema
}
// TableInfo2Schema converts table info to schema.
func TableInfo2Schema(tbl *model.TableInfo) Schema {
schema := make(Schema, 0, len(tbl.Columns))
for i, col := range tbl.Columns {
newCol := &Column{
ColName: col.Name,
TblName: tbl.Name,
RetType: &col.FieldType,
Position: i,
}
schema = append(schema, newCol)
}
return schema
}

View File

@ -180,6 +180,9 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
case *ast.SubqueryExpr:
return er.handleScalarSubquery(v)
case *ast.ParenthesesExpr:
case *ast.ValuesExpr:
er.valuesToScalarFunc(v)
return inNode, true
default:
er.asScalar = true
}
@ -413,7 +416,7 @@ func (er *expressionRewriter) Leave(inNode ast.Node) (retNode ast.Node, ok bool)
switch v := inNode.(type) {
case *ast.AggregateFuncExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.WhenClause,
*ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr:
*ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr, *ast.ValuesExpr:
case *ast.ValueExpr:
value := &expression.Constant{Value: v.Datum, RetType: &v.Type}
er.ctxStack = append(er.ctxStack, value)
@ -855,3 +858,13 @@ func (er *expressionRewriter) castToScalarFunc(v *ast.FuncCastExpr) {
ArgValues: make([]types.Datum, 1)}
er.ctxStack[len(er.ctxStack)-1] = function
}
func (er *expressionRewriter) valuesToScalarFunc(v *ast.ValuesExpr) {
bt := evaluator.BuildinValuesFactory(v)
function := &expression.ScalarFunction{
FuncName: model.NewCIStr(ast.Values),
RetType: &v.Type,
Function: bt,
}
er.ctxStack = append(er.ctxStack, function)
}

View File

@ -553,10 +553,10 @@ func (s *testPlanSuite) TestPlanBuilder(c *C) {
sql: "explain select * from t union all select * from t limit 1, 1",
plan: "UnionAll{Table(t)->Table(t)->Limit}->*plan.Explain",
},
{
sql: "insert into t select * from t",
plan: "DataScan(t)->Projection->*plan.Insert",
},
//{
// sql: "insert into t select * from t",
// plan: "DataScan(t)->Projection->*plan.Insert",
//},
{
sql: "show columns from t where `Key` = 'pri' like 't*'",
plan: "*plan.Show->Selection",

View File

@ -22,6 +22,7 @@ import (
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/util/charset"
"github.com/pingcap/tidb/util/types"
@ -31,14 +32,23 @@ import (
var (
ErrUnsupportedType = terror.ClassOptimizerPlan.New(CodeUnsupportedType, "Unsupported type")
SystemInternalErrorType = terror.ClassOptimizerPlan.New(SystemInternalError, "System internal error")
ErrUnknownColumn = terror.ClassOptimizerPlan.New(CodeUnknownColumn, "Unknown column '%s' in '%s'")
)
// Error codes.
const (
CodeUnsupportedType terror.ErrCode = 1
SystemInternalError terror.ErrCode = 2
CodeUnknownColumn terror.ErrCode = 1054
)
func init() {
tableMySQLErrCodes := map[terror.ErrCode]uint16{
CodeUnknownColumn: mysql.ErrBadField,
}
terror.ErrClassToMySQLCodes[terror.ClassOptimizerPlan] = tableMySQLErrCodes
}
// planBuilder builds Plan from an ast.Node.
// It just builds the ast node straightforwardly.
type planBuilder struct {
@ -435,18 +445,124 @@ func (b *planBuilder) buildSimple(node ast.StmtNode) Plan {
return &Simple{Statement: node}
}
func (b *planBuilder) getDefaultValue(col *table.Column) (*expression.Constant, error) {
if value, ok, err := table.GetColDefaultValue(b.ctx, col.ToInfo()); ok {
if err != nil {
return nil, errors.Trace(err)
}
return &expression.Constant{Value: value, RetType: &col.FieldType}, nil
}
return &expression.Constant{RetType: &col.FieldType}, nil
}
func (b *planBuilder) findDefaultValue(cols []*table.Column, name *ast.ColumnName) (*expression.Constant, error) {
for _, col := range cols {
if col.Name.L == name.Name.L {
return b.getDefaultValue(col)
}
}
return nil, ErrUnknownColumn.GenByArgs(name.Name.O, "field_list")
}
func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan {
// Get Table
ts, ok := insert.Table.TableRefs.Left.(*ast.TableSource)
if !ok {
b.err = infoschema.ErrTableNotExists.GenByArgs()
return nil
}
tn, ok := ts.Source.(*ast.TableName)
if !ok {
b.err = infoschema.ErrTableNotExists.GenByArgs()
return nil
}
tableInfo := tn.TableInfo
schema := expression.TableInfo2Schema(tableInfo)
table, ok := b.is.TableByID(tableInfo.ID)
if !ok {
b.err = errors.Errorf("Can't get table %s.", tableInfo.Name.O)
return nil
}
insertPlan := &Insert{
Table: insert.Table,
Table: table,
Columns: insert.Columns,
Lists: insert.Lists,
Setlist: insert.Setlist,
OnDuplicate: insert.OnDuplicate,
tableSchema: schema,
IsReplace: insert.IsReplace,
Priority: insert.Priority,
Ignore: insert.Ignore,
baseLogicalPlan: newBaseLogicalPlan(Ins, b.allocator),
}
cols := table.Cols()
for _, valuesItem := range insert.Lists {
exprList := make([]expression.Expression, 0, len(valuesItem))
for i, valueItem := range valuesItem {
var expr expression.Expression
var err error
if dft, ok := valueItem.(*ast.DefaultExpr); ok {
if dft.Name != nil {
expr, err = b.findDefaultValue(cols, dft.Name)
} else {
expr, err = b.getDefaultValue(cols[i])
}
} else if val, ok := valueItem.(*ast.ValueExpr); ok {
expr = &expression.Constant{
Value: val.Datum,
RetType: &val.Type,
}
} else {
expr, _, err = b.rewrite(valueItem, nil, nil, true)
}
if err != nil {
b.err = errors.Trace(err)
}
exprList = append(exprList, expr)
}
insertPlan.Lists = append(insertPlan.Lists, exprList)
}
for _, assign := range insert.Setlist {
col, err := schema.FindColumn(assign.Column)
if err != nil {
b.err = errors.Trace(err)
return nil
}
if col == nil {
b.err = errors.Errorf("Can't find column %s", assign.Column)
return nil
}
// Here we keep different behaviours with MySQL. MySQL allow set a = b, b = a and the result is NULL, NULL.
// It's unreasonable.
expr, _, err := b.rewrite(assign.Expr, nil, nil, true)
if err != nil {
b.err = errors.Trace(err)
return nil
}
insertPlan.Setlist = append(insertPlan.Setlist, &expression.Assignment{
Col: col,
Expr: expr,
})
}
mockTablePlan := &TableDual{}
mockTablePlan.SetSchema(schema)
for _, assign := range insert.OnDuplicate {
col, err := schema.FindColumn(assign.Column)
if err != nil {
b.err = errors.Trace(err)
return nil
}
if col == nil {
b.err = errors.Errorf("Can't find column %s", assign.Column)
return nil
}
expr, _, err := b.rewrite(assign.Expr, mockTablePlan, nil, true)
if err != nil {
b.err = errors.Trace(err)
return nil
}
insertPlan.OnDuplicate = append(insertPlan.OnDuplicate, &expression.Assignment{
Col: col,
Expr: expr,
})
}
insertPlan.initIDAndContext(b.ctx)
insertPlan.self = insertPlan
if insert.Select != nil {

View File

@ -20,6 +20,7 @@ import (
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/util/types"
)
@ -176,11 +177,12 @@ type Simple struct {
type Insert struct {
baseLogicalPlan
Table *ast.TableRefsClause
Table table.Table
tableSchema expression.Schema
Columns []*ast.ColumnName
Lists [][]ast.ExprNode
Setlist []*ast.Assignment
OnDuplicate []*ast.Assignment
Lists [][]expression.Expression
Setlist []*expression.Assignment
OnDuplicate []*expression.Assignment
IsReplace bool
Priority int

View File

@ -130,3 +130,11 @@ func (p *Update) ResolveIndicesAndCorCols() {
}
p.OrderedList = orderedList
}
// ResolveIndicesAndCorCols implements LogicalPlan interface.
func (p *Insert) ResolveIndicesAndCorCols() {
p.baseLogicalPlan.ResolveIndicesAndCorCols()
for _, asgn := range p.OnDuplicate {
asgn.Expr.ResolveIndices(p.tableSchema)
}
}