importinto: use same type context flag setting as insert (#58606)

close pingcap/tidb#58443
This commit is contained in:
D3Hunter
2024-12-31 15:02:33 +08:00
committed by GitHub
parent 42d4fae449
commit 284a3ee23c
5 changed files with 65 additions and 8 deletions

View File

@ -1049,6 +1049,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
errLevels := sc.ErrLevels()
errLevels[errctx.ErrGroupDividedByZero] = errctx.LevelWarn
inImportInto := false
switch stmt := s.(type) {
// `ResetUpdateStmtCtx` and `ResetDeleteStmtCtx` may modify the flags, so we'll need to store them.
case *ast.UpdateStmt:
@ -1077,12 +1078,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
!strictSQLMode || stmt.IgnoreErr,
)
sc.Priority = stmt.Priority
sc.SetTypeFlags(sc.TypeFlags().
WithTruncateAsWarning(!strictSQLMode || stmt.IgnoreErr).
WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode()).
WithIgnoreZeroInDate(!vars.SQLMode.HasNoZeroInDateMode() ||
!vars.SQLMode.HasNoZeroDateMode() || !strictSQLMode || stmt.IgnoreErr ||
vars.SQLMode.HasAllowInvalidDatesMode()))
sc.SetTypeFlags(util.GetTypeFlagsForInsert(sc.TypeFlags(), vars.SQLMode, stmt.IgnoreErr))
case *ast.CreateTableStmt, *ast.AlterTableStmt:
sc.InCreateOrAlterStmt = true
sc.SetTypeFlags(sc.TypeFlags().
@ -1096,6 +1092,9 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
sc.InLoadDataStmt = true
// return warning instead of error when load data meet no partition for value
errLevels[errctx.ErrGroupNoMatchedPartition] = errctx.LevelWarn
case *ast.ImportIntoStmt:
inImportInto = true
sc.SetTypeFlags(util.GetTypeFlagsForImportInto(sc.TypeFlags(), vars.SQLMode))
case *ast.SelectStmt:
sc.InSelectStmt = true
@ -1153,7 +1152,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
// WithAllowNegativeToUnsigned with false value indicates values less than 0 should be clipped to 0 for unsigned integer types.
// This is the case for `insert`, `update`, `alter table`, `create table` and `load data infile` statements, when not in strict SQL mode.
// see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html
WithAllowNegativeToUnsigned(!sc.InInsertStmt && !sc.InLoadDataStmt && !sc.InUpdateStmt && !sc.InCreateOrAlterStmt),
WithAllowNegativeToUnsigned(!sc.InInsertStmt && !sc.InLoadDataStmt && !inImportInto && !sc.InUpdateStmt && !sc.InCreateOrAlterStmt),
)
vars.PlanCacheParams.Reset()

View File

@ -15,12 +15,15 @@
package executor_test
import (
"fmt"
"testing"
"github.com/pingcap/tidb/pkg/domain"
"github.com/pingcap/tidb/pkg/executor"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/util/mock"
"github.com/stretchr/testify/require"
)
func BenchmarkResetContextOfStmt(b *testing.B) {
@ -31,3 +34,37 @@ func BenchmarkResetContextOfStmt(b *testing.B) {
executor.ResetContextOfStmt(ctx, stmt)
}
}
func TestImportIntoShouldHaveSameFlagsAsInsert(t *testing.T) {
insertStmt := &ast.InsertStmt{}
importStmt := &ast.ImportIntoStmt{}
insertCtx := mock.NewContext()
importCtx := mock.NewContext()
insertCtx.BindDomain(&domain.Domain{})
importCtx.BindDomain(&domain.Domain{})
for _, modeStr := range []string{
"",
"IGNORE_SPACE",
"STRICT_TRANS_TABLES",
"STRICT_ALL_TABLES",
"ALLOW_INVALID_DATES",
"NO_ZERO_IN_DATE",
"NO_ZERO_DATE",
"NO_ZERO_IN_DATE,STRICT_ALL_TABLES",
"NO_ZERO_DATE,STRICT_ALL_TABLES",
"NO_ZERO_IN_DATE,NO_ZERO_DATE,STRICT_ALL_TABLES",
} {
t.Run(fmt.Sprintf("mode %s", modeStr), func(t *testing.T) {
mode, err := mysql.GetSQLMode(modeStr)
require.NoError(t, err)
insertCtx.GetSessionVars().SQLMode = mode
require.NoError(t, executor.ResetContextOfStmt(insertCtx, insertStmt))
importCtx.GetSessionVars().SQLMode = mode
require.NoError(t, executor.ResetContextOfStmt(importCtx, importStmt))
insertTypeCtx := insertCtx.GetSessionVars().StmtCtx.TypeCtx()
importTypeCtx := importCtx.GetSessionVars().StmtCtx.TypeCtx()
require.EqualValues(t, insertTypeCtx.Flags(), importTypeCtx.Flags())
})
}
}

View File

@ -34,7 +34,9 @@ const (
// FlagTruncateAsWarning indicates to append the truncate error to warnings instead of returning it to user.
FlagTruncateAsWarning
// FlagAllowNegativeToUnsigned indicates to allow the casting from negative to unsigned int.
// When this flag is not set by default, casting a negative value to unsigned results an overflow error.
// When this flag is not set by default, casting a negative value to unsigned
// results an overflow error, but if SQL mode is not strict, it's converted
// to 0 with a warning.
// Otherwise, a negative value will be cast to the corresponding unsigned value without any error.
// For example, when casting -1 to an unsigned bigint with `FlagAllowNegativeToUnsigned` set,
// we will get `18446744073709551615` which is the biggest unsigned value.

View File

@ -41,6 +41,7 @@ go_library(
"//pkg/session/cursor",
"//pkg/session/txninfo",
"//pkg/sessionctx/stmtctx",
"//pkg/types",
"//pkg/util/collate",
"//pkg/util/disk",
"//pkg/util/execdetails",

View File

@ -46,6 +46,7 @@ import (
pmodel "github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/terror"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/collate"
"github.com/pingcap/tidb/pkg/util/logutil"
tlsutil "github.com/pingcap/tidb/pkg/util/tls"
@ -693,3 +694,20 @@ func createTLSCertificates(certpath string, keypath string, rsaKeySize int) erro
// use RSA and unspecified signature algorithm
return CreateCertificates(certpath, keypath, rsaKeySize, x509.RSA, x509.UnknownSignatureAlgorithm)
}
// GetTypeFlagsForInsert gets the type flags for insert statement.
func GetTypeFlagsForInsert(baseFlags types.Flags, sqlMode mysql.SQLMode, ignoreErr bool) types.Flags {
strictSQLMode := sqlMode.HasStrictMode()
return baseFlags.
WithTruncateAsWarning(!strictSQLMode || ignoreErr).
WithIgnoreInvalidDateErr(sqlMode.HasAllowInvalidDatesMode()).
WithIgnoreZeroInDate(!sqlMode.HasNoZeroInDateMode() ||
!sqlMode.HasNoZeroDateMode() || !strictSQLMode || ignoreErr ||
sqlMode.HasAllowInvalidDatesMode())
}
// GetTypeFlagsForImportInto gets the type flags for import into statement which
// has the same flags as normal `INSERT INTO xxx`.
func GetTypeFlagsForImportInto(baseFlags types.Flags, sqlMode mysql.SQLMode) types.Flags {
return GetTypeFlagsForInsert(baseFlags, sqlMode, false)
}