diff --git a/column/column.go b/column/column.go index 1e5a16cbd5..fee76a9e45 100644 --- a/column/column.go +++ b/column/column.go @@ -161,7 +161,12 @@ func (c *Col) GetTypeDesc() string { switch c.Tp { case mysql.TypeSet, mysql.TypeEnum: // Format is ENUM ('e1', 'e2') or SET ('e1', 'e2') - buf.WriteString(fmt.Sprintf("('%s')", strings.Join(c.Elems, "','"))) + // If elem contain ', we will convert ' -> '' + elems := make([]string, len(c.Elems)) + for i := range elems { + elems[i] = strings.Replace(c.Elems[i], "'", "''", -1) + } + buf.WriteString(fmt.Sprintf("('%s')", strings.Join(elems, "','"))) default: if c.Flen != -1 { if c.Decimal == -1 { diff --git a/column/column_test.go b/column/column_test.go index 5ff94d0c68..11baace5d6 100644 --- a/column/column_test.go +++ b/column/column_test.go @@ -52,6 +52,9 @@ func (s *testColumnSuite) TestString(c *C) { col.Elems = []string{"a", "b"} c.Assert(col.GetTypeDesc(), Equals, "enum('a','b')") + + col.Elems = []string{"'a'", "b"} + c.Assert(col.GetTypeDesc(), Equals, "enum('''a''','b')") } func (s *testColumnSuite) TestFind(c *C) { diff --git a/stmt/stmts/stmt_helper.go b/stmt/stmts/stmt_helper.go index e9b3b1cdf0..d8ebe8682a 100644 --- a/stmt/stmts/stmt_helper.go +++ b/stmt/stmts/stmt_helper.go @@ -25,7 +25,7 @@ import ( func getDefaultValue(ctx context.Context, c *column.Col) (interface{}, bool, error) { // Check no default value flag. - if mysql.HasNoDefaultValueFlag(c.Flag) { + if mysql.HasNoDefaultValueFlag(c.Flag) && c.Tp != mysql.TypeEnum { return nil, false, errors.Errorf("Field '%s' doesn't have a default value", c.Name) } @@ -41,6 +41,12 @@ func getDefaultValue(ctx context.Context, c *column.Col) (interface{}, bool, err } return value, true, nil + } else if c.Tp == mysql.TypeEnum { + // For enum type, if no default value and not null is set, + // the default value is the first element of the enum list + if c.DefaultValue == nil && mysql.HasNotNullFlag(c.Flag) { + return c.FieldType.Elems[0], true, nil + } } return c.DefaultValue, true, nil diff --git a/stmt/stmts/stmt_helper_test.go b/stmt/stmts/stmt_helper_test.go index 1ce496120d..a9d50737d2 100644 --- a/stmt/stmts/stmt_helper_test.go +++ b/stmt/stmts/stmt_helper_test.go @@ -53,4 +53,27 @@ func (s *testStmtSuite) TestGetColDefaultValue(c *C) { testSQL = " insert helper_test (c1) values (1);" mustExec(c, s.testDB, testSQL) + + testSQL = `drop table if exists helper_test; + create table helper_test (c1 enum("a"), c2 enum("b", "e") not null, c3 enum("c") default "c", c4 enum("d") default "d" not null);` + mustExec(c, s.testDB, testSQL) + + testSQL = "insert into helper_test values();" + mustExec(c, s.testDB, testSQL) + + row := s.testDB.QueryRow("select * from helper_test") + var ( + v1 interface{} + v2 interface{} + v3 interface{} + v4 interface{} + ) + + err = row.Scan(&v1, &v2, &v3, &v4) + c.Assert(err, IsNil) + c.Assert(v1, IsNil) + c.Assert(v2, Equals, "b") + c.Assert(v3, Equals, "c") + c.Assert(v4, Equals, "d") + }