expression: fix type infer for tidb's builtin compare(least and greatest) (#21150)

Signed-off-by: iosmanthus <myosmanthustree@gmail.com>
This commit is contained in:
Iosmanthus Teng
2020-12-22 14:58:31 +08:00
committed by GitHub
parent cf806f60e4
commit dd0dc46d5e
8 changed files with 173 additions and 134 deletions

View File

@ -15,6 +15,7 @@ package expression
import (
"math"
"strings"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
@ -367,53 +368,67 @@ func (b *builtinCoalesceJSONSig) evalJSON(row chunk.Row) (res json.BinaryJSON, i
return res, isNull, err
}
// temporalWithDateAsNumEvalType makes DATE, DATETIME, TIMESTAMP pretend to be numbers rather than strings.
func temporalWithDateAsNumEvalType(argTp *types.FieldType) (argEvalType types.EvalType, isStr bool, isTemporalWithDate bool) {
argEvalType = argTp.EvalType()
isStr, isTemporalWithDate = argEvalType.IsStringKind(), types.IsTemporalWithDate(argTp.Tp)
if !isTemporalWithDate {
return
func aggregateType(args []Expression) *types.FieldType {
fieldTypes := make([]*types.FieldType, len(args))
for i := range fieldTypes {
fieldTypes[i] = args[i].GetType()
}
if argTp.Decimal > 0 {
argEvalType = types.ETDecimal
} else {
argEvalType = types.ETInt
}
return
return types.AggFieldType(fieldTypes)
}
// GetCmpTp4MinMax gets compare type for GREATEST and LEAST and BETWEEN
func GetCmpTp4MinMax(args []Expression) (argTp types.EvalType) {
datetimeFound, isAllStr := false, true
cmpEvalType, isStr, isTemporalWithDate := temporalWithDateAsNumEvalType(args[0].GetType())
if !isStr {
isAllStr = false
// ResolveType4Between resolves eval type for between expression.
func ResolveType4Between(args [3]Expression) types.EvalType {
cmpTp := args[0].GetType().EvalType()
for i := 1; i < 3; i++ {
cmpTp = getBaseCmpType(cmpTp, args[i].GetType().EvalType(), nil, nil)
}
if isTemporalWithDate {
datetimeFound = true
}
lft := args[0].GetType()
for i := range args {
rft := args[i].GetType()
var tp types.EvalType
tp, isStr, isTemporalWithDate = temporalWithDateAsNumEvalType(rft)
if isTemporalWithDate {
datetimeFound = true
hasTemporal := false
if cmpTp == types.ETString {
for _, arg := range args {
if types.IsTypeTemporal(arg.GetType().Tp) {
hasTemporal = true
break
}
}
if !isStr {
isAllStr = false
if hasTemporal {
cmpTp = types.ETDatetime
}
cmpEvalType = getBaseCmpType(cmpEvalType, tp, lft, rft)
lft = rft
}
argTp = cmpEvalType
if cmpEvalType.IsStringKind() {
argTp = types.ETString
return cmpTp
}
// resolveType4Extremum gets the compare type for GREATEST and LEAST (mainly for datetime).
func resolveType4Extremum(args []Expression) types.EvalType {
aggType := aggregateType(args)
var temporalItem *types.FieldType
if aggType.EvalType().IsStringKind() {
for i := range args {
item := args[i].GetType()
if types.IsTemporalWithDate(item.Tp) {
temporalItem = item
}
}
if !types.IsTemporalWithDate(aggType.Tp) && temporalItem != nil {
aggType.Tp = temporalItem.Tp
}
// TODO: String charset, collation checking are needed.
}
if isAllStr && datetimeFound {
argTp = types.ETDatetime
return aggType.EvalType()
}
// unsupportedJSONComparison appends a warning when any argument of the least/greatest function has the JSON type.
func unsupportedJSONComparison(ctx sessionctx.Context, args []Expression) {
for _, arg := range args {
tp := arg.GetType().Tp
if tp == mysql.TypeJSON {
ctx.GetSessionVars().StmtCtx.AppendWarning(errUnsupportedJSONComparison)
break
}
}
return argTp
}
type greatestFunctionClass struct {
@ -424,10 +439,14 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
if err = c.verifyArgs(args); err != nil {
return nil, err
}
tp, cmpAsDatetime := GetCmpTp4MinMax(args), false
if tp == types.ETDatetime {
tp := resolveType4Extremum(args)
cmpAsDatetime := false
if tp == types.ETDatetime || tp == types.ETTimestamp {
cmpAsDatetime = true
tp = types.ETString
} else if tp == types.ETJson {
unsupportedJSONComparison(ctx, args)
tp = types.ETString
}
argTps := make([]types.EvalType, len(args))
for i := range args {
@ -453,7 +472,7 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
case types.ETString:
sig = &builtinGreatestStringSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_GreatestString)
case types.ETDatetime:
case types.ETDatetime, types.ETTimestamp:
sig = &builtinGreatestTimeSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_GreatestTime)
}
@ -592,30 +611,39 @@ func (b *builtinGreatestTimeSig) Clone() builtinFunc {
// evalString evals a builtinGreatestTimeSig.
// See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_greatest
func (b *builtinGreatestTimeSig) evalString(row chunk.Row) (_ string, isNull bool, err error) {
func (b *builtinGreatestTimeSig) evalString(row chunk.Row) (res string, isNull bool, err error) {
var (
v string
t types.Time
strRes string
timeRes types.Time
)
max := types.ZeroDatetime
sc := b.ctx.GetSessionVars().StmtCtx
for i := 0; i < len(b.args); i++ {
v, isNull, err = b.args[i].EvalString(b.ctx, row)
v, isNull, err := b.args[i].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, err
}
t, err = types.ParseDatetime(sc, v)
t, err := types.ParseDatetime(sc, v)
if err != nil {
if err = handleInvalidTimeError(b.ctx, err); err != nil {
return v, true, err
}
continue
} else {
v = t.String()
}
if t.Compare(max) > 0 {
max = t
// In MySQL, if the compare result is zero, then we will try to use the string comparison result.
if i == 0 || strings.Compare(v, strRes) > 0 {
strRes = v
}
if i == 0 || t.Compare(timeRes) > 0 {
timeRes = t
}
}
return max.String(), false, nil
if timeRes.IsZero() {
res = strRes
} else {
res = timeRes.String()
}
return res, false, nil
}
type leastFunctionClass struct {
@ -626,10 +654,14 @@ func (c *leastFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
if err = c.verifyArgs(args); err != nil {
return nil, err
}
tp, cmpAsDatetime := GetCmpTp4MinMax(args), false
tp := resolveType4Extremum(args)
cmpAsDatetime := false
if tp == types.ETDatetime {
cmpAsDatetime = true
tp = types.ETString
} else if tp == types.ETJson {
unsupportedJSONComparison(ctx, args)
tp = types.ETString
}
argTps := make([]types.EvalType, len(args))
for i := range args {
@ -796,32 +828,36 @@ func (b *builtinLeastTimeSig) Clone() builtinFunc {
// See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_least
func (b *builtinLeastTimeSig) evalString(row chunk.Row) (res string, isNull bool, err error) {
var (
v string
t types.Time
// timeRes is rendered as the string result only when it is non-zero; otherwise strRes is returned.
strRes string // Record the strRes of each arguments.
timeRes types.Time // Record the time representation of a valid arguments.
)
min := types.NewTime(types.MaxDatetime, mysql.TypeDatetime, types.MaxFsp)
findInvalidTime := false
sc := b.ctx.GetSessionVars().StmtCtx
for i := 0; i < len(b.args); i++ {
v, isNull, err = b.args[i].EvalString(b.ctx, row)
v, isNull, err := b.args[i].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, err
}
t, err = types.ParseDatetime(sc, v)
t, err := types.ParseDatetime(sc, v)
if err != nil {
if err = handleInvalidTimeError(b.ctx, err); err != nil {
return v, true, err
} else if !findInvalidTime {
res = v
findInvalidTime = true
}
} else {
v = t.String()
}
if t.Compare(min) < 0 {
min = t
if i == 0 || strings.Compare(v, strRes) < 0 {
strRes = v
}
if i == 0 || t.Compare(timeRes) < 0 {
timeRes = t
}
}
if !findInvalidTime {
res = min.String()
if timeRes.IsZero() {
res = strRes
} else {
res = timeRes.String()
}
return res, false, nil
}
@ -1042,7 +1078,7 @@ type compareFunctionClass struct {
// getBaseCmpType gets the EvalType that the two args will be treated as when comparing.
func getBaseCmpType(lhs, rhs types.EvalType, lft, rft *types.FieldType) types.EvalType {
if lft.Tp == mysql.TypeUnspecified || rft.Tp == mysql.TypeUnspecified {
if lft != nil && rft != nil && (lft.Tp == mysql.TypeUnspecified || rft.Tp == mysql.TypeUnspecified) {
if lft.Tp == rft.Tp {
return types.ETString
}
@ -1054,13 +1090,13 @@ func getBaseCmpType(lhs, rhs types.EvalType, lft, rft *types.FieldType) types.Ev
}
if lhs.IsStringKind() && rhs.IsStringKind() {
return types.ETString
} else if (lhs == types.ETInt || lft.Hybrid()) && (rhs == types.ETInt || rft.Hybrid()) {
} else if (lhs == types.ETInt || (lft != nil && lft.Hybrid())) && (rhs == types.ETInt || (rft != nil && rft.Hybrid())) {
return types.ETInt
} else if ((lhs == types.ETInt || lft.Hybrid()) || lhs == types.ETDecimal) &&
((rhs == types.ETInt || rft.Hybrid()) || rhs == types.ETDecimal) {
} else if ((lhs == types.ETInt || (lft != nil && lft.Hybrid())) || lhs == types.ETDecimal) &&
((rhs == types.ETInt || (rft != nil && rft.Hybrid())) || rhs == types.ETDecimal) {
return types.ETDecimal
} else if types.IsTemporalWithDate(lft.Tp) && rft.Tp == mysql.TypeYear ||
lft.Tp == mysql.TypeYear && types.IsTemporalWithDate(rft.Tp) {
} else if lft != nil && rft != nil && (types.IsTemporalWithDate(lft.Tp) && rft.Tp == mysql.TypeYear ||
lft.Tp == mysql.TypeYear && types.IsTemporalWithDate(rft.Tp)) {
return types.ETDatetime
}
return types.ETReal

View File

@ -258,7 +258,8 @@ func (s *testEvaluatorSuite) TestIntervalFunc(c *C) {
}
}
func (s *testEvaluatorSuite) TestGreatestLeastFuncs(c *C) {
// greatest/least function is compatible with MySQL 8.0
func (s *testEvaluatorSuite) TestGreatestLeastFunc(c *C) {
sc := s.ctx.GetSessionVars().StmtCtx
originIgnoreTruncate := sc.IgnoreTruncate
sc.IgnoreTruncate = true
@ -283,7 +284,7 @@ func (s *testEvaluatorSuite) TestGreatestLeastFuncs(c *C) {
},
{
[]interface{}{"123a", "b", "c", 12},
float64(123), float64(0), false, false,
"c", "12", false, false,
},
{
[]interface{}{tm, "123"},
@ -291,15 +292,15 @@ func (s *testEvaluatorSuite) TestGreatestLeastFuncs(c *C) {
},
{
[]interface{}{tm, 123},
curTimeInt, int64(123), false, false,
curTimeString, "123", false, false,
},
{
[]interface{}{tm, "invalid_time_1", "invalid_time_2", tmWithFsp},
curTimeWithFspString, "invalid_time_1", false, false,
curTimeWithFspString, curTimeString, false, false,
},
{
[]interface{}{tm, "invalid_time_2", "invalid_time_1", tmWithFsp},
curTimeWithFspString, "invalid_time_2", false, false,
curTimeWithFspString, curTimeString, false, false,
},
{
[]interface{}{tm, "invalid_time", nil, tmWithFsp},
@ -317,6 +318,14 @@ func (s *testEvaluatorSuite) TestGreatestLeastFuncs(c *C) {
[]interface{}{errors.New("must error"), 123},
nil, nil, false, true,
},
{
[]interface{}{794755072.0, 4556, "2000-01-09"},
"794755072", "2000-01-09", false, false,
},
{
[]interface{}{905969664.0, 4556, "1990-06-16 17:22:56.005534"},
"905969664", "1990-06-16 17:22:56.005534", false, false,
},
} {
f0, err := newFunctionForTest(s.ctx, ast.Greatest, s.primitiveValsToConstants(t.args)...)
c.Assert(err, IsNil)

View File

@ -14,6 +14,8 @@
package expression
import (
"strings"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
@ -633,47 +635,46 @@ func (b *builtinGreatestTimeSig) vectorized() bool {
}
func (b *builtinGreatestTimeSig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error {
n := input.NumRows()
dst, err := b.bufAllocator.get(types.ETTimestamp, n)
if err != nil {
return err
}
defer b.bufAllocator.put(dst)
sc := b.ctx.GetSessionVars().StmtCtx
dst.ResizeTime(n, false)
dstTimes := dst.Times()
for i := 0; i < n; i++ {
dstTimes[i] = types.ZeroDatetime
}
var argTime types.Time
n := input.NumRows()
dstStrings := make([]string, n)
// TODO: use Column.MergeNulls instead, however, it doesn't support var-length type currently.
dstNullMap := make([]bool, n)
for j := 0; j < len(b.args); j++ {
if err := b.args[j].VecEvalString(b.ctx, input, result); err != nil {
return err
}
for i := 0; i < n; i++ {
if result.IsNull(i) || dst.IsNull(i) {
dst.SetNull(i, true)
if dstNullMap[i] = dstNullMap[i] || result.IsNull(i); dstNullMap[i] {
continue
}
argTime, err = types.ParseDatetime(sc, result.GetString(i))
// NOTE: can't use Column.GetString because it returns an unsafe string, copy the row instead.
argTimeStr := string(result.GetBytes(i))
argTime, err := types.ParseDatetime(sc, argTimeStr)
if err != nil {
if err = handleInvalidTimeError(b.ctx, err); err != nil {
return err
}
continue
} else {
argTimeStr = argTime.String()
}
if argTime.Compare(dstTimes[i]) > 0 {
dstTimes[i] = argTime
if j == 0 || strings.Compare(argTimeStr, dstStrings[i]) > 0 {
dstStrings[i] = argTimeStr
}
}
}
// Aggregate the NULL and String value into result
result.ReserveString(n)
for i := 0; i < n; i++ {
if dst.IsNull(i) {
if dstNullMap[i] {
result.AppendNull()
} else {
result.AppendString(dstTimes[i].String())
result.AppendString(dstStrings[i])
}
}
return nil
@ -719,60 +720,46 @@ func (b *builtinLeastTimeSig) vectorized() bool {
}
func (b *builtinLeastTimeSig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error {
n := input.NumRows()
dst, err := b.bufAllocator.get(types.ETTimestamp, n)
if err != nil {
return err
}
defer b.bufAllocator.put(dst)
sc := b.ctx.GetSessionVars().StmtCtx
dst.ResizeTime(n, false)
dstTimes := dst.Times()
for i := 0; i < n; i++ {
dstTimes[i] = types.NewTime(types.MaxDatetime, mysql.TypeDatetime, types.DefaultFsp)
}
var argTime types.Time
n := input.NumRows()
findInvalidTime := make([]bool, n)
invalidValue := make([]string, n)
dstStrings := make([]string, n)
// TODO: use Column.MergeNulls instead, however, it doesn't support var-length type currently.
dstNullMap := make([]bool, n)
for j := 0; j < len(b.args); j++ {
if err := b.args[j].VecEvalString(b.ctx, input, result); err != nil {
return err
}
dst.MergeNulls(result)
for i := 0; i < n; i++ {
if dst.IsNull(i) {
if dstNullMap[i] = dstNullMap[i] || result.IsNull(i); dstNullMap[i] {
continue
}
argTime, err = types.ParseDatetime(sc, result.GetString(i))
// NOTE: can't use Column.GetString because it returns an unsafe string, copy the row instead.
argTimeStr := string(result.GetBytes(i))
argTime, err := types.ParseDatetime(sc, argTimeStr)
if err != nil {
if err = handleInvalidTimeError(b.ctx, err); err != nil {
return err
} else if !findInvalidTime[i] {
// Make a deep copy here.
// Otherwise invalidValue will internally change with result.
invalidValue[i] = string(result.GetBytes(i))
findInvalidTime[i] = true
}
continue
} else {
argTimeStr = argTime.String()
}
if argTime.Compare(dstTimes[i]) < 0 {
dstTimes[i] = argTime
if j == 0 || strings.Compare(argTimeStr, dstStrings[i]) < 0 {
dstStrings[i] = argTimeStr
}
}
}
// Aggregate the NULL and String value into result
result.ReserveString(n)
for i := 0; i < n; i++ {
if dst.IsNull(i) {
if dstNullMap[i] {
result.AppendNull()
continue
}
if findInvalidTime[i] {
result.AppendString(invalidValue[i])
} else {
result.AppendString(dstTimes[i].String())
result.AppendString(dstStrings[i])
}
}
return nil

View File

@ -50,7 +50,9 @@ var (
errNonUniq = dbterror.ClassExpression.NewStd(mysql.ErrNonUniq)
// Sequence usage privilege check.
errSequenceAccessDenied = dbterror.ClassExpression.NewStd(mysql.ErrTableaccessDenied)
errSequenceAccessDenied = dbterror.ClassExpression.NewStd(mysql.ErrTableaccessDenied)
errUnsupportedJSONComparison = dbterror.ClassExpression.NewStdErr(mysql.ErrNotSupportedYet,
pmysql.Message("comparison of JSON in the LEAST and GREATEST operators", nil))
)
// handleInvalidTimeError reports error or warning depend on the context.

View File

@ -3712,8 +3712,8 @@ func (s *testIntegrationSuite) TestCompareBuiltin(c *C) {
// for greatest
result = tk.MustQuery(`select greatest(1, 2, 3), greatest("a", "b", "c"), greatest(1.1, 1.2, 1.3), greatest("123a", 1, 2)`)
result.Check(testkit.Rows("3 c 1.3 123"))
tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect FLOAT value: '123a'"))
result.Check(testkit.Rows("3 c 1.3 2"))
tk.MustQuery("show warnings").Check(testkit.Rows())
result = tk.MustQuery(`select greatest(cast("2017-01-01" as datetime), "123", "234", cast("2018-01-01" as date)), greatest(cast("2017-01-01" as date), "123", null)`)
// todo: MySQL returns "2018-01-01 <nil>"
result.Check(testkit.Rows("2018-01-01 00:00:00 <nil>"))
@ -3721,7 +3721,7 @@ func (s *testIntegrationSuite) TestCompareBuiltin(c *C) {
// for least
result = tk.MustQuery(`select least(1, 2, 3), least("a", "b", "c"), least(1.1, 1.2, 1.3), least("123a", 1, 2)`)
result.Check(testkit.Rows("1 a 1.1 1"))
tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect FLOAT value: '123a'"))
tk.MustQuery("show warnings").Check(testkit.Rows())
result = tk.MustQuery(`select least(cast("2017-01-01" as datetime), "123", "234", cast("2018-01-01" as date)), least(cast("2017-01-01" as date), "123", null)`)
result.Check(testkit.Rows("123 <nil>"))
tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Incorrect time value: '123'", "Warning|1292|Incorrect time value: '234'", "Warning|1292|Incorrect time value: '123'"))

View File

@ -1545,7 +1545,7 @@ func (er *expressionRewriter) wrapExpWithCast() (expr, lexp, rexp expression.Exp
stkLen := len(er.ctxStack)
expr, lexp, rexp = er.ctxStack[stkLen-3], er.ctxStack[stkLen-2], er.ctxStack[stkLen-1]
var castFunc func(sessionctx.Context, expression.Expression) expression.Expression
switch expression.GetCmpTp4MinMax([]expression.Expression{expr, lexp, rexp}) {
switch expression.ResolveType4Between([3]expression.Expression{expr, lexp, rexp}) {
case types.ETInt:
castFunc = expression.WrapWithCastAsInt
case types.ETReal:

View File

@ -72,6 +72,7 @@ func (s *testExpressionSuite) TestBetween(c *C) {
{exprStr: "1 not between 2 and 3", resultStr: "1"},
{exprStr: "'2001-04-10 12:34:56' between cast('2001-01-01 01:01:01' as datetime) and '01-05-01'", resultStr: "1"},
{exprStr: "20010410123456 between cast('2001-01-01 01:01:01' as datetime) and 010501", resultStr: "0"},
{exprStr: "20010410123456 between cast('2001-01-01 01:01:01' as datetime) and 20010501123456", resultStr: "1"},
}
s.runTests(c, tests)
}

View File

@ -99,6 +99,10 @@ func AggFieldType(tps []*FieldType) *FieldType {
}
}
if mysql.HasUnsignedFlag(currType.Flag) && !isMixedSign {
currType.Flag |= mysql.UnsignedFlag
}
return &currType
}