expression, plan: rewrite builtin funcs setvar, getvar (#4479)

This commit is contained in:
HuaiyuXu
2017-09-11 16:18:33 +08:00
committed by Han Fei
parent d4bfcf0272
commit 0fd962f297
4 changed files with 55 additions and 37 deletions

View File

@ -18,6 +18,7 @@ import (
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/util/types"
)
@ -75,64 +76,73 @@ type setVarFunctionClass struct {
baseFunctionClass
}
func (c *setVarFunctionClass) getFunction(ctx context.Context, args []Expression) (builtinFunc, error) {
err := errors.Trace(c.verifyArgs(args))
bt := &builtinSetVarSig{newBaseBuiltinFunc(args, ctx)}
bt.foldable = false
return bt.setSelf(bt), errors.Trace(err)
func (c *setVarFunctionClass) getFunction(ctx context.Context, args []Expression) (sig builtinFunc, err error) {
if err = errors.Trace(c.verifyArgs(args)); err != nil {
return nil, err
}
bf := newBaseBuiltinFuncWithTp(args, ctx, tpString, tpString, tpString)
bf.tp.Flen, bf.foldable = args[1].GetType().Flen, false
// TODO: we should consider the type of the argument, but not take it as string for all situations.
sig = &builtinSetVarSig{baseStringBuiltinFunc{bf}}
return sig.setSelf(sig), errors.Trace(err)
}
type builtinSetVarSig struct {
baseBuiltinFunc
baseStringBuiltinFunc
}
func (b *builtinSetVarSig) eval(row []types.Datum) (types.Datum, error) {
args, err := b.evalArgs(row)
if err != nil {
return types.Datum{}, errors.Trace(err)
}
func (b *builtinSetVarSig) evalString(row []types.Datum) (res string, isNull bool, err error) {
var varName string
sessionVars := b.ctx.GetSessionVars()
varName, _ := args[0].ToString()
if !args[1].IsNull() {
strVal, err := args[1].ToString()
if err != nil {
return types.Datum{}, errors.Trace(err)
}
sessionVars.UsersLock.Lock()
sessionVars.Users[varName] = strings.ToLower(strVal)
sessionVars.UsersLock.Unlock()
sc := sessionVars.StmtCtx
varName, isNull, err = b.args[0].EvalString(row, sc)
if isNull || err != nil {
return "", isNull, errors.Trace(err)
}
return args[1], nil
res, isNull, err = b.args[1].EvalString(row, sc)
if isNull || err != nil {
return "", isNull, errors.Trace(err)
}
varName = strings.ToLower(varName)
sessionVars.UsersLock.Lock()
sessionVars.Users[varName] = res
sessionVars.UsersLock.Unlock()
return res, false, nil
}
type getVarFunctionClass struct {
baseFunctionClass
}
func (c *getVarFunctionClass) getFunction(ctx context.Context, args []Expression) (builtinFunc, error) {
err := errors.Trace(c.verifyArgs(args))
bt := &builtinGetVarSig{newBaseBuiltinFunc(args, ctx)}
bt.foldable = false
return bt.setSelf(bt), errors.Trace(err)
func (c *getVarFunctionClass) getFunction(ctx context.Context, args []Expression) (sig builtinFunc, err error) {
if err = errors.Trace(c.verifyArgs(args)); err != nil {
return nil, err
}
// TODO: we should consider the type of the argument, but not take it as string for all situations.
bf := newBaseBuiltinFuncWithTp(args, ctx, tpString, tpString)
bf.tp.Flen, bf.foldable = mysql.MaxFieldVarCharLength, false
sig = &builtinGetVarSig{baseStringBuiltinFunc{bf}}
return sig.setSelf(sig), nil
}
type builtinGetVarSig struct {
baseBuiltinFunc
baseStringBuiltinFunc
}
func (b *builtinGetVarSig) eval(row []types.Datum) (types.Datum, error) {
args, err := b.evalArgs(row)
if err != nil {
return types.Datum{}, errors.Trace(err)
}
func (b *builtinGetVarSig) evalString(row []types.Datum) (string, bool, error) {
sessionVars := b.ctx.GetSessionVars()
varName, _ := args[0].ToString()
sc := sessionVars.StmtCtx
varName, isNull, err := b.args[0].EvalString(row, sc)
if isNull || err != nil {
return "", isNull, errors.Trace(err)
}
varName = strings.ToLower(varName)
sessionVars.UsersLock.RLock()
defer sessionVars.UsersLock.RUnlock()
if v, ok := sessionVars.Users[varName]; ok {
return types.NewDatum(v), nil
return v, false, nil
}
return types.Datum{}, nil
return "", true, nil
}
type valuesFunctionClass struct {

View File

@ -16,7 +16,6 @@ package expression
import (
"fmt"
"math"
"strings"
. "github.com/pingcap/check"
"github.com/pingcap/tidb/ast"
@ -123,7 +122,7 @@ func (s *testEvaluatorSuite) TestSetVar(c *C) {
c.Assert(ok, Equals, true)
val, ok := tc.res.(string)
c.Assert(ok, Equals, true)
c.Assert(s.ctx.GetSessionVars().Users[key], Equals, strings.ToLower(val))
c.Assert(s.ctx.GetSessionVars().Users[key], Equals, val)
}
}
}

View File

@ -2378,6 +2378,8 @@ func (s *testIntegrationSuite) TestOtherBuiltin(c *C) {
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int, b double, c varchar(20), d datetime, e time)")
tk.MustExec("insert into t value(1, 2, 'string', '2017-01-01 12:12:12', '12:12:12')")
// for in
result := tk.MustQuery("select 1 in (a, b, c), 'string' in (a, b, c), '2017-01-01 12:12:12' in (c, d, e), '12:12:12' in (c, d, e) from t")
result.Check(testkit.Rows("1 1 1 1"))
result = tk.MustQuery("select 1 in (null, c), 2 in (null, c) from t")
@ -2389,6 +2391,11 @@ func (s *testIntegrationSuite) TestOtherBuiltin(c *C) {
result = tk.MustQuery(`select bit_count(121), bit_count(-1), bit_count(null), bit_count("1231aaa");`)
result.Check(testkit.Rows("5 64 <nil> 7"))
// for setvar, getvar
tk.MustExec(`set @varname = "Abc"`)
result = tk.MustQuery(`select @varname, @VARNAME`)
result.Check(testkit.Rows("Abc Abc"))
}
func (s *testIntegrationSuite) TestDateBuiltin(c *C) {

View File

@ -1125,6 +1125,8 @@ func (s *testPlanSuite) createTestCase4OtherFuncs() []typeInferTestCase {
{"bit_count(c_blob_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 2, 0},
{"bit_count(c_set )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 2, 0},
{"bit_count(c_enum )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 2, 0},
{`@varname`, mysql.TypeVarString, charset.CharsetUTF8, 0, mysql.MaxFieldVarCharLength, types.UnspecifiedFsp},
}
}