expression, plan: rewrite builtin funcs setvar, getvar (#4479)
This commit is contained in:
@ -18,6 +18,7 @@ import (
|
||||
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/context"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
@ -75,64 +76,73 @@ type setVarFunctionClass struct {
|
||||
baseFunctionClass
|
||||
}
|
||||
|
||||
func (c *setVarFunctionClass) getFunction(ctx context.Context, args []Expression) (builtinFunc, error) {
|
||||
err := errors.Trace(c.verifyArgs(args))
|
||||
bt := &builtinSetVarSig{newBaseBuiltinFunc(args, ctx)}
|
||||
bt.foldable = false
|
||||
return bt.setSelf(bt), errors.Trace(err)
|
||||
func (c *setVarFunctionClass) getFunction(ctx context.Context, args []Expression) (sig builtinFunc, err error) {
|
||||
if err = errors.Trace(c.verifyArgs(args)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bf := newBaseBuiltinFuncWithTp(args, ctx, tpString, tpString, tpString)
|
||||
bf.tp.Flen, bf.foldable = args[1].GetType().Flen, false
|
||||
// TODO: we should consider the type of the argument, but not take it as string for all situations.
|
||||
sig = &builtinSetVarSig{baseStringBuiltinFunc{bf}}
|
||||
return sig.setSelf(sig), errors.Trace(err)
|
||||
}
|
||||
|
||||
type builtinSetVarSig struct {
|
||||
baseBuiltinFunc
|
||||
baseStringBuiltinFunc
|
||||
}
|
||||
|
||||
func (b *builtinSetVarSig) eval(row []types.Datum) (types.Datum, error) {
|
||||
args, err := b.evalArgs(row)
|
||||
if err != nil {
|
||||
return types.Datum{}, errors.Trace(err)
|
||||
}
|
||||
func (b *builtinSetVarSig) evalString(row []types.Datum) (res string, isNull bool, err error) {
|
||||
var varName string
|
||||
sessionVars := b.ctx.GetSessionVars()
|
||||
varName, _ := args[0].ToString()
|
||||
if !args[1].IsNull() {
|
||||
strVal, err := args[1].ToString()
|
||||
if err != nil {
|
||||
return types.Datum{}, errors.Trace(err)
|
||||
}
|
||||
sessionVars.UsersLock.Lock()
|
||||
sessionVars.Users[varName] = strings.ToLower(strVal)
|
||||
sessionVars.UsersLock.Unlock()
|
||||
sc := sessionVars.StmtCtx
|
||||
varName, isNull, err = b.args[0].EvalString(row, sc)
|
||||
if isNull || err != nil {
|
||||
return "", isNull, errors.Trace(err)
|
||||
}
|
||||
return args[1], nil
|
||||
res, isNull, err = b.args[1].EvalString(row, sc)
|
||||
if isNull || err != nil {
|
||||
return "", isNull, errors.Trace(err)
|
||||
}
|
||||
varName = strings.ToLower(varName)
|
||||
sessionVars.UsersLock.Lock()
|
||||
sessionVars.Users[varName] = res
|
||||
sessionVars.UsersLock.Unlock()
|
||||
return res, false, nil
|
||||
}
|
||||
|
||||
type getVarFunctionClass struct {
|
||||
baseFunctionClass
|
||||
}
|
||||
|
||||
func (c *getVarFunctionClass) getFunction(ctx context.Context, args []Expression) (builtinFunc, error) {
|
||||
err := errors.Trace(c.verifyArgs(args))
|
||||
bt := &builtinGetVarSig{newBaseBuiltinFunc(args, ctx)}
|
||||
bt.foldable = false
|
||||
return bt.setSelf(bt), errors.Trace(err)
|
||||
func (c *getVarFunctionClass) getFunction(ctx context.Context, args []Expression) (sig builtinFunc, err error) {
|
||||
if err = errors.Trace(c.verifyArgs(args)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO: we should consider the type of the argument, but not take it as string for all situations.
|
||||
bf := newBaseBuiltinFuncWithTp(args, ctx, tpString, tpString)
|
||||
bf.tp.Flen, bf.foldable = mysql.MaxFieldVarCharLength, false
|
||||
sig = &builtinGetVarSig{baseStringBuiltinFunc{bf}}
|
||||
return sig.setSelf(sig), nil
|
||||
}
|
||||
|
||||
type builtinGetVarSig struct {
|
||||
baseBuiltinFunc
|
||||
baseStringBuiltinFunc
|
||||
}
|
||||
|
||||
func (b *builtinGetVarSig) eval(row []types.Datum) (types.Datum, error) {
|
||||
args, err := b.evalArgs(row)
|
||||
if err != nil {
|
||||
return types.Datum{}, errors.Trace(err)
|
||||
}
|
||||
func (b *builtinGetVarSig) evalString(row []types.Datum) (string, bool, error) {
|
||||
sessionVars := b.ctx.GetSessionVars()
|
||||
varName, _ := args[0].ToString()
|
||||
sc := sessionVars.StmtCtx
|
||||
varName, isNull, err := b.args[0].EvalString(row, sc)
|
||||
if isNull || err != nil {
|
||||
return "", isNull, errors.Trace(err)
|
||||
}
|
||||
varName = strings.ToLower(varName)
|
||||
sessionVars.UsersLock.RLock()
|
||||
defer sessionVars.UsersLock.RUnlock()
|
||||
if v, ok := sessionVars.Users[varName]; ok {
|
||||
return types.NewDatum(v), nil
|
||||
return v, false, nil
|
||||
}
|
||||
return types.Datum{}, nil
|
||||
return "", true, nil
|
||||
}
|
||||
|
||||
type valuesFunctionClass struct {
|
||||
|
||||
@ -16,7 +16,6 @@ package expression
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
. "github.com/pingcap/check"
|
||||
"github.com/pingcap/tidb/ast"
|
||||
@ -123,7 +122,7 @@ func (s *testEvaluatorSuite) TestSetVar(c *C) {
|
||||
c.Assert(ok, Equals, true)
|
||||
val, ok := tc.res.(string)
|
||||
c.Assert(ok, Equals, true)
|
||||
c.Assert(s.ctx.GetSessionVars().Users[key], Equals, strings.ToLower(val))
|
||||
c.Assert(s.ctx.GetSessionVars().Users[key], Equals, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2378,6 +2378,8 @@ func (s *testIntegrationSuite) TestOtherBuiltin(c *C) {
|
||||
tk.MustExec("drop table if exists t")
|
||||
tk.MustExec("create table t(a int, b double, c varchar(20), d datetime, e time)")
|
||||
tk.MustExec("insert into t value(1, 2, 'string', '2017-01-01 12:12:12', '12:12:12')")
|
||||
|
||||
// for in
|
||||
result := tk.MustQuery("select 1 in (a, b, c), 'string' in (a, b, c), '2017-01-01 12:12:12' in (c, d, e), '12:12:12' in (c, d, e) from t")
|
||||
result.Check(testkit.Rows("1 1 1 1"))
|
||||
result = tk.MustQuery("select 1 in (null, c), 2 in (null, c) from t")
|
||||
@ -2389,6 +2391,11 @@ func (s *testIntegrationSuite) TestOtherBuiltin(c *C) {
|
||||
|
||||
result = tk.MustQuery(`select bit_count(121), bit_count(-1), bit_count(null), bit_count("1231aaa");`)
|
||||
result.Check(testkit.Rows("5 64 <nil> 7"))
|
||||
|
||||
// for setvar, getvar
|
||||
tk.MustExec(`set @varname = "Abc"`)
|
||||
result = tk.MustQuery(`select @varname, @VARNAME`)
|
||||
result.Check(testkit.Rows("Abc Abc"))
|
||||
}
|
||||
|
||||
func (s *testIntegrationSuite) TestDateBuiltin(c *C) {
|
||||
|
||||
@ -1125,6 +1125,8 @@ func (s *testPlanSuite) createTestCase4OtherFuncs() []typeInferTestCase {
|
||||
{"bit_count(c_blob_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 2, 0},
|
||||
{"bit_count(c_set )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 2, 0},
|
||||
{"bit_count(c_enum )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 2, 0},
|
||||
|
||||
{`@varname`, mysql.TypeVarString, charset.CharsetUTF8, 0, mysql.MaxFieldVarCharLength, types.UnspecifiedFsp},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user