diff --git a/expression/builtin_other.go b/expression/builtin_other.go index 0abdc14ea8..b4c51a2653 100644 --- a/expression/builtin_other.go +++ b/expression/builtin_other.go @@ -18,6 +18,7 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/types" ) @@ -75,64 +76,73 @@ type setVarFunctionClass struct { baseFunctionClass } -func (c *setVarFunctionClass) getFunction(ctx context.Context, args []Expression) (builtinFunc, error) { - err := errors.Trace(c.verifyArgs(args)) - bt := &builtinSetVarSig{newBaseBuiltinFunc(args, ctx)} - bt.foldable = false - return bt.setSelf(bt), errors.Trace(err) +func (c *setVarFunctionClass) getFunction(ctx context.Context, args []Expression) (sig builtinFunc, err error) { + if err = errors.Trace(c.verifyArgs(args)); err != nil { + return nil, err + } + bf := newBaseBuiltinFuncWithTp(args, ctx, tpString, tpString, tpString) + bf.tp.Flen, bf.foldable = args[1].GetType().Flen, false + // TODO: we should consider the type of the argument, but not take it as string for all situations. + sig = &builtinSetVarSig{baseStringBuiltinFunc{bf}} + return sig.setSelf(sig), errors.Trace(err) } type builtinSetVarSig struct { - baseBuiltinFunc + baseStringBuiltinFunc } -func (b *builtinSetVarSig) eval(row []types.Datum) (types.Datum, error) { - args, err := b.evalArgs(row) - if err != nil { - return types.Datum{}, errors.Trace(err) - } +func (b *builtinSetVarSig) evalString(row []types.Datum) (res string, isNull bool, err error) { + var varName string sessionVars := b.ctx.GetSessionVars() - varName, _ := args[0].ToString() - if !args[1].IsNull() { - strVal, err := args[1].ToString() - if err != nil { - return types.Datum{}, errors.Trace(err) - } - sessionVars.UsersLock.Lock() - sessionVars.Users[varName] = strings.ToLower(strVal) - sessionVars.UsersLock.Unlock() + sc := sessionVars.StmtCtx + varName, isNull, err = b.args[0].EvalString(row, sc) + if isNull || err != nil { + return "", isNull, errors.Trace(err) } - return args[1], nil + res, isNull, err = b.args[1].EvalString(row, sc) + if isNull || err != nil { + return "", isNull, errors.Trace(err) + } + varName = strings.ToLower(varName) + sessionVars.UsersLock.Lock() + sessionVars.Users[varName] = res + sessionVars.UsersLock.Unlock() + return res, false, nil } type getVarFunctionClass struct { baseFunctionClass } -func (c *getVarFunctionClass) getFunction(ctx context.Context, args []Expression) (builtinFunc, error) { - err := errors.Trace(c.verifyArgs(args)) - bt := &builtinGetVarSig{newBaseBuiltinFunc(args, ctx)} - bt.foldable = false - return bt.setSelf(bt), errors.Trace(err) +func (c *getVarFunctionClass) getFunction(ctx context.Context, args []Expression) (sig builtinFunc, err error) { + if err = errors.Trace(c.verifyArgs(args)); err != nil { + return nil, err + } + // TODO: we should consider the type of the argument, but not take it as string for all situations. + bf := newBaseBuiltinFuncWithTp(args, ctx, tpString, tpString) + bf.tp.Flen, bf.foldable = mysql.MaxFieldVarCharLength, false + sig = &builtinGetVarSig{baseStringBuiltinFunc{bf}} + return sig.setSelf(sig), nil } type builtinGetVarSig struct { - baseBuiltinFunc + baseStringBuiltinFunc } -func (b *builtinGetVarSig) eval(row []types.Datum) (types.Datum, error) { - args, err := b.evalArgs(row) - if err != nil { - return types.Datum{}, errors.Trace(err) - } +func (b *builtinGetVarSig) evalString(row []types.Datum) (string, bool, error) { sessionVars := b.ctx.GetSessionVars() - varName, _ := args[0].ToString() + sc := sessionVars.StmtCtx + varName, isNull, err := b.args[0].EvalString(row, sc) + if isNull || err != nil { + return "", isNull, errors.Trace(err) + } + varName = strings.ToLower(varName) sessionVars.UsersLock.RLock() defer sessionVars.UsersLock.RUnlock() if v, ok := sessionVars.Users[varName]; ok { - return types.NewDatum(v), nil + return v, false, nil } - return types.Datum{}, nil + return "", true, nil } type valuesFunctionClass struct { diff --git a/expression/builtin_other_test.go b/expression/builtin_other_test.go index 9bee104aec..8d0f8d9e82 100644 --- a/expression/builtin_other_test.go +++ b/expression/builtin_other_test.go @@ -16,7 +16,6 @@ package expression import ( "fmt" "math" - "strings" . "github.com/pingcap/check" "github.com/pingcap/tidb/ast" @@ -123,7 +122,7 @@ func (s *testEvaluatorSuite) TestSetVar(c *C) { c.Assert(ok, Equals, true) val, ok := tc.res.(string) c.Assert(ok, Equals, true) - c.Assert(s.ctx.GetSessionVars().Users[key], Equals, strings.ToLower(val)) + c.Assert(s.ctx.GetSessionVars().Users[key], Equals, val) } } } diff --git a/expression/integration_test.go b/expression/integration_test.go index 2cd2655cb8..a4e7fccce1 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -2378,6 +2378,8 @@ func (s *testIntegrationSuite) TestOtherBuiltin(c *C) { tk.MustExec("drop table if exists t") tk.MustExec("create table t(a int, b double, c varchar(20), d datetime, e time)") tk.MustExec("insert into t value(1, 2, 'string', '2017-01-01 12:12:12', '12:12:12')") + + // for in result := tk.MustQuery("select 1 in (a, b, c), 'string' in (a, b, c), '2017-01-01 12:12:12' in (c, d, e), '12:12:12' in (c, d, e) from t") result.Check(testkit.Rows("1 1 1 1")) result = tk.MustQuery("select 1 in (null, c), 2 in (null, c) from t") @@ -2389,6 +2391,11 @@ func (s *testIntegrationSuite) TestOtherBuiltin(c *C) { result = tk.MustQuery(`select bit_count(121), bit_count(-1), bit_count(null), bit_count("1231aaa");`) result.Check(testkit.Rows("5 64 7")) + + // for setvar, getvar + tk.MustExec(`set @varname = "Abc"`) + result = tk.MustQuery(`select @varname, @VARNAME`) + result.Check(testkit.Rows("Abc Abc")) } func (s *testIntegrationSuite) TestDateBuiltin(c *C) { diff --git a/plan/typeinfer_test.go b/plan/typeinfer_test.go index 32bfb3b547..be8312c087 100644 --- a/plan/typeinfer_test.go +++ b/plan/typeinfer_test.go @@ -1125,6 +1125,8 @@ func (s *testPlanSuite) createTestCase4OtherFuncs() []typeInferTestCase { {"bit_count(c_blob_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 2, 0}, {"bit_count(c_set )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 2, 0}, {"bit_count(c_enum )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 2, 0}, + + {`@varname`, mysql.TypeVarString, charset.CharsetUTF8, 0, mysql.MaxFieldVarCharLength, types.UnspecifiedFsp}, } }