From e61ee664f5e6a57facbdc760be3ba55327dfd2fd Mon Sep 17 00:00:00 2001 From: YangKeao Date: Fri, 22 Dec 2023 15:05:16 +0800 Subject: [PATCH] errctx, stmtctx: add cache to the errctx (#49689) close pingcap/tidb#49688 --- pkg/sessionctx/stmtctx/BUILD.bazel | 2 +- pkg/sessionctx/stmtctx/stmtctx.go | 13 +++++++++---- pkg/sessionctx/stmtctx/stmtctx_test.go | 11 +++++++++++ 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/pkg/sessionctx/stmtctx/BUILD.bazel b/pkg/sessionctx/stmtctx/BUILD.bazel index 20ecfbdc9f..c8726e67cd 100644 --- a/pkg/sessionctx/stmtctx/BUILD.bazel +++ b/pkg/sessionctx/stmtctx/BUILD.bazel @@ -41,7 +41,7 @@ go_test( ], embed = [":stmtctx"], flaky = True, - shard_count = 11, + shard_count = 12, deps = [ "//pkg/kv", "//pkg/sessionctx/variable", diff --git a/pkg/sessionctx/stmtctx/stmtctx.go b/pkg/sessionctx/stmtctx/stmtctx.go index 6ce1ea2175..a114252232 100644 --- a/pkg/sessionctx/stmtctx/stmtctx.go +++ b/pkg/sessionctx/stmtctx/stmtctx.go @@ -442,6 +442,7 @@ func NewStmtCtxWithTimeZone(tz *time.Location) *StatementContext { ctxID: stmtCtxIDGenerator.Add(1), } sc.typeCtx = types.NewContext(types.DefaultStmtFlags, tz, sc) + sc.initErrCtx() return sc } @@ -451,6 +452,7 @@ func (sc *StatementContext) Reset() { ctxID: stmtCtxIDGenerator.Add(1), typeCtx: types.NewContext(types.DefaultStmtFlags, time.UTC, sc), } + sc.initErrCtx() } // CtxID returns the context id of the statement @@ -479,9 +481,7 @@ func (sc *StatementContext) TypeCtx() types.Context { return sc.typeCtx } -// ErrCtx returns the error context -// TODO: add a cache to the `ErrCtx` if needed, though it's not a big burden to generate `ErrCtx` everytime. -func (sc *StatementContext) ErrCtx() errctx.Context { +func (sc *StatementContext) initErrCtx() { ctx := errctx.NewContext(sc) if sc.TypeFlags().IgnoreTruncateErr() { @@ -489,8 +489,12 @@ func (sc *StatementContext) ErrCtx() errctx.Context { } else if sc.TypeFlags().TruncateAsWarning() { ctx = ctx.WithErrGroupLevel(errctx.ErrGroupTruncate, errctx.LevelWarn) } + sc.errCtx = ctx +} - return ctx +// ErrCtx returns the error context +func (sc *StatementContext) ErrCtx() errctx.Context { + return sc.errCtx } // TypeFlags returns the type flags @@ -501,6 +505,7 @@ func (sc *StatementContext) TypeFlags() types.Flags { // SetTypeFlags sets the type flags func (sc *StatementContext) SetTypeFlags(flags types.Flags) { sc.typeCtx = sc.typeCtx.WithFlags(flags) + sc.initErrCtx() } // HandleTruncate ignores or returns the error based on the TypeContext inside. diff --git a/pkg/sessionctx/stmtctx/stmtctx_test.go b/pkg/sessionctx/stmtctx/stmtctx_test.go index a93a269915..5aa6fcad8f 100644 --- a/pkg/sessionctx/stmtctx/stmtctx_test.go +++ b/pkg/sessionctx/stmtctx/stmtctx_test.go @@ -408,6 +408,17 @@ func TestStmtCtxID(t *testing.T) { } } +func TestErrCtx(t *testing.T) { + sc := stmtctx.NewStmtCtx() + // the default errCtx + err := types.ErrTruncated + require.Error(t, sc.HandleError(err)) + + // reset the types flags will re-initialize the error flag + sc.SetTypeFlags(types.DefaultStmtFlags | types.FlagTruncateAsWarning) + require.NoError(t, sc.HandleError(err)) +} + func BenchmarkErrCtx(b *testing.B) { sc := stmtctx.NewStmtCtx()