// Copyright 2017 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package expression
|
|
|
|
import (
|
|
"math"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pingcap/tidb/pkg/expression/exprstatic"
|
|
"github.com/pingcap/tidb/pkg/parser/ast"
|
|
"github.com/pingcap/tidb/pkg/parser/mysql"
|
|
"github.com/pingcap/tidb/pkg/testkit/testutil"
|
|
"github.com/pingcap/tidb/pkg/types"
|
|
"github.com/pingcap/tidb/pkg/util/chunk"
|
|
"github.com/pingcap/tipb/go-tipb"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestSetFlenDecimal4RealOrDecimal(t *testing.T) {
|
|
ctx := exprstatic.NewEvalContext()
|
|
ret := &types.FieldType{}
|
|
a := &types.FieldType{}
|
|
a.SetDecimal(1)
|
|
a.SetFlen(3)
|
|
|
|
b := &types.FieldType{}
|
|
b.SetDecimal(0)
|
|
b.SetFlag(2)
|
|
setFlenDecimal4RealOrDecimal(ctx, ret, &Constant{RetType: a}, &Constant{RetType: b}, true, false)
|
|
require.Equal(t, 1, ret.GetDecimal())
|
|
require.Equal(t, 4, ret.GetFlen())
|
|
|
|
b.SetFlen(65)
|
|
setFlenDecimal4RealOrDecimal(ctx, ret, &Constant{RetType: a}, &Constant{RetType: b}, true, false)
|
|
require.Equal(t, 1, ret.GetDecimal())
|
|
require.Equal(t, mysql.MaxRealWidth, ret.GetFlen())
|
|
setFlenDecimal4RealOrDecimal(ctx, ret, &Constant{RetType: a}, &Constant{RetType: b}, false, false)
|
|
require.Equal(t, 1, ret.GetDecimal())
|
|
require.Equal(t, mysql.MaxDecimalWidth, ret.GetFlen())
|
|
|
|
b.SetFlen(types.UnspecifiedLength)
|
|
setFlenDecimal4RealOrDecimal(ctx, ret, &Constant{RetType: a}, &Constant{RetType: b}, true, false)
|
|
require.Equal(t, 1, ret.GetDecimal())
|
|
require.Equal(t, types.UnspecifiedLength, ret.GetFlen())
|
|
|
|
b.SetDecimal(types.UnspecifiedLength)
|
|
setFlenDecimal4RealOrDecimal(ctx, ret, &Constant{RetType: a}, &Constant{RetType: b}, true, false)
|
|
require.Equal(t, types.UnspecifiedLength, ret.GetDecimal())
|
|
require.Equal(t, types.UnspecifiedLength, ret.GetFlen())
|
|
|
|
ret = &types.FieldType{}
|
|
a = &types.FieldType{}
|
|
a.SetDecimal(1)
|
|
a.SetFlen(3)
|
|
|
|
b = &types.FieldType{}
|
|
b.SetDecimal(0)
|
|
b.SetFlen(2)
|
|
|
|
setFlenDecimal4RealOrDecimal(ctx, ret, &Constant{RetType: a}, &Constant{RetType: b}, true, true)
|
|
require.Equal(t, 1, ret.GetDecimal())
|
|
require.Equal(t, 5, ret.GetFlen())
|
|
|
|
b.SetFlen(65)
|
|
setFlenDecimal4RealOrDecimal(ctx, ret, &Constant{RetType: a}, &Constant{RetType: b}, true, true)
|
|
require.Equal(t, 1, ret.GetDecimal())
|
|
require.Equal(t, mysql.MaxRealWidth, ret.GetFlen())
|
|
setFlenDecimal4RealOrDecimal(ctx, ret, &Constant{RetType: a}, &Constant{RetType: b}, false, true)
|
|
require.Equal(t, 1, ret.GetDecimal())
|
|
require.Equal(t, mysql.MaxDecimalWidth, ret.GetFlen())
|
|
|
|
b.SetFlen(types.UnspecifiedLength)
|
|
setFlenDecimal4RealOrDecimal(ctx, ret, &Constant{RetType: a}, &Constant{RetType: b}, true, true)
|
|
require.Equal(t, 1, ret.GetDecimal())
|
|
require.Equal(t, types.UnspecifiedLength, ret.GetFlen())
|
|
|
|
b.SetDecimal(types.UnspecifiedLength)
|
|
setFlenDecimal4RealOrDecimal(ctx, ret, &Constant{RetType: a}, &Constant{RetType: b}, true, true)
|
|
require.Equal(t, types.UnspecifiedLength, ret.GetDecimal())
|
|
require.Equal(t, types.UnspecifiedLength, ret.GetFlen())
|
|
}
|
|
|
|
func TestArithmeticPlus(t *testing.T) {
|
|
ctx := createContext(t)
|
|
// case: 1
|
|
args := []any{int64(12), int64(1)}
|
|
|
|
bf, err := funcs[ast.Plus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
|
|
require.NoError(t, err)
|
|
require.NotNil(t, bf)
|
|
intSig, ok := bf.(*builtinArithmeticPlusIntSig)
|
|
require.True(t, ok)
|
|
require.NotNil(t, intSig)
|
|
|
|
intResult, isNull, err := intSig.evalInt(ctx, chunk.Row{})
|
|
require.NoError(t, err)
|
|
require.False(t, isNull)
|
|
require.Equal(t, int64(13), intResult)
|
|
|
|
// case 2
|
|
args = []any{1.01001, -0.01}
|
|
|
|
bf, err = funcs[ast.Plus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
|
|
require.NoError(t, err)
|
|
require.NotNil(t, bf)
|
|
realSig, ok := bf.(*builtinArithmeticPlusRealSig)
|
|
require.True(t, ok)
|
|
require.NotNil(t, realSig)
|
|
|
|
realResult, isNull, err := realSig.evalReal(ctx, chunk.Row{})
|
|
require.NoError(t, err)
|
|
require.False(t, isNull)
|
|
require.Equal(t, 1.00001, realResult)
|
|
|
|
// case 3
|
|
args = []any{nil, -0.11101}
|
|
|
|
bf, err = funcs[ast.Plus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
|
|
require.NoError(t, err)
|
|
require.NotNil(t, bf)
|
|
realSig, ok = bf.(*builtinArithmeticPlusRealSig)
|
|
require.True(t, ok)
|
|
require.NotNil(t, realSig)
|
|
|
|
realResult, isNull, err = realSig.evalReal(ctx, chunk.Row{})
|
|
require.NoError(t, err)
|
|
require.True(t, isNull)
|
|
require.Equal(t, float64(0), realResult)
|
|
|
|
// case 4
|
|
args = []any{nil, nil}
|
|
|
|
bf, err = funcs[ast.Plus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
|
|
require.NoError(t, err)
|
|
require.NotNil(t, bf)
|
|
realSig, ok = bf.(*builtinArithmeticPlusRealSig)
|
|
require.True(t, ok)
|
|
require.NotNil(t, realSig)
|
|
|
|
realResult, isNull, err = realSig.evalReal(ctx, chunk.Row{})
|
|
require.NoError(t, err)
|
|
require.True(t, isNull)
|
|
require.Equal(t, float64(0), realResult)
|
|
|
|
// case 5
|
|
hexStr, err := types.ParseHexStr("0x20000000000000")
|
|
require.NoError(t, err)
|
|
args = []any{hexStr, int64(1)}
|
|
|
|
bf, err = funcs[ast.Plus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
|
|
require.NoError(t, err)
|
|
require.NotNil(t, bf)
|
|
intSig, ok = bf.(*builtinArithmeticPlusIntSig)
|
|
require.True(t, ok)
|
|
require.NotNil(t, intSig)
|
|
|
|
intResult, _, err = intSig.evalInt(ctx, chunk.Row{})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(9007199254740993), intResult)
|
|
|
|
bitStr, err := types.NewBitLiteral("0b00011")
|
|
require.NoError(t, err)
|
|
args = []any{bitStr, int64(1)}
|
|
|
|
bf, err = funcs[ast.Plus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
|
|
require.NoError(t, err)
|
|
require.NotNil(t, bf)
|
|
|
|
//check the result type is int
|
|
intSig, ok = bf.(*builtinArithmeticPlusIntSig)
|
|
require.True(t, ok)
|
|
require.NotNil(t, intSig)
|
|
|
|
intResult, _, err = intSig.evalInt(ctx, chunk.Row{})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(4), intResult)
|
|
}
|
|
|
|
func TestArithmeticMinus(t *testing.T) {
|
|
ctx := createContext(t)
|
|
// case: 1
|
|
args := []any{int64(12), int64(1)}
|
|
|
|
bf, err := funcs[ast.Minus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
|
|
require.NoError(t, err)
|
|
require.NotNil(t, bf)
|
|
intSig, ok := bf.(*builtinArithmeticMinusIntSig)
|
|
require.True(t, ok)
|
|
require.NotNil(t, intSig)
|
|
|
|
intResult, isNull, err := intSig.evalInt(ctx, chunk.Row{})
|
|
require.NoError(t, err)
|
|
require.False(t, isNull)
|
|
require.Equal(t, int64(11), intResult)
|
|
|
|
// case 2
|
|
args = []any{1.01001, -0.01}
|
|
|
|
bf, err = funcs[ast.Minus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
|
|
require.NoError(t, err)
|
|
require.NotNil(t, bf)
|
|
realSig, ok := bf.(*builtinArithmeticMinusRealSig)
|
|
require.True(t, ok)
|
|
require.NotNil(t, realSig)
|
|
|
|
realResult, isNull, err := realSig.evalReal(ctx, chunk.Row{})
|
|
require.NoError(t, err)
|
|
require.False(t, isNull)
|
|
require.Equal(t, 1.02001, realResult)
|
|
|
|
// case 3
|
|
args = []any{nil, -0.11101}
|
|
|
|
bf, err = funcs[ast.Minus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
|
|
require.NoError(t, err)
|
|
require.NotNil(t, bf)
|
|
realSig, ok = bf.(*builtinArithmeticMinusRealSig)
|
|
require.True(t, ok)
|
|
require.NotNil(t, realSig)
|
|
|
|
realResult, isNull, err = realSig.evalReal(ctx, chunk.Row{})
|
|
require.NoError(t, err)
|
|
require.True(t, isNull)
|
|
require.Equal(t, float64(0), realResult)
|
|
|
|
// case 4
|
|
args = []any{1.01, nil}
|
|
|
|
bf, err = funcs[ast.Minus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
|
|
require.NoError(t, err)
|
|
require.NotNil(t, bf)
|
|
realSig, ok = bf.(*builtinArithmeticMinusRealSig)
|
|
require.True(t, ok)
|
|
require.NotNil(t, realSig)
|
|
|
|
realResult, isNull, err = realSig.evalReal(ctx, chunk.Row{})
|
|
require.NoError(t, err)
|
|
require.True(t, isNull)
|
|
require.Equal(t, float64(0), realResult)
|
|
|
|
// case 5
|
|
args = []any{nil, nil}
|
|
|
|
bf, err = funcs[ast.Minus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
|
|
require.NoError(t, err)
|
|
require.NotNil(t, bf)
|
|
realSig, ok = bf.(*builtinArithmeticMinusRealSig)
|
|
require.True(t, ok)
|
|
require.NotNil(t, realSig)
|
|
|
|
realResult, isNull, err = realSig.evalReal(ctx, chunk.Row{})
|
|
require.NoError(t, err)
|
|
require.True(t, isNull)
|
|
require.Equal(t, float64(0), realResult)
|
|
}
|
|
|
|
func TestArithmeticMultiply(t *testing.T) {
|
|
ctx := createContext(t)
|
|
testCases := []struct {
|
|
args []any
|
|
expect []any
|
|
err error
|
|
}{
|
|
{
|
|
args: []any{int64(11), int64(11)},
|
|
expect: []any{int64(121), nil},
|
|
},
|
|
{
|
|
args: []any{int64(-1), int64(math.MinInt64)},
|
|
expect: []any{nil, "BIGINT value is out of range in '\\(-1 \\* -9223372036854775808\\)'$"},
|
|
},
|
|
{
|
|
args: []any{int64(math.MinInt64), int64(-1)},
|
|
expect: []any{nil, "BIGINT value is out of range in '\\(-9223372036854775808 \\* -1\\)'$"},
|
|
},
|
|
{
|
|
args: []any{uint64(11), uint64(11)},
|
|
expect: []any{int64(121), nil},
|
|
},
|
|
{
|
|
args: []any{float64(11), float64(11)},
|
|
expect: []any{float64(121), nil},
|
|
},
|
|
{
|
|
args: []any{nil, -0.11101},
|
|
expect: []any{nil, nil},
|
|
},
|
|
{
|
|
args: []any{1.01, nil},
|
|
expect: []any{nil, nil},
|
|
},
|
|
{
|
|
args: []any{nil, nil},
|
|
expect: []any{nil, nil},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
sig, err := funcs[ast.Mul].getFunction(ctx, datumsToConstants(types.MakeDatums(tc.args...)))
|
|
require.NoError(t, err)
|
|
require.NotNil(t, sig)
|
|
val, err := evalBuiltinFunc(sig, ctx, chunk.Row{})
|
|
if tc.expect[1] == nil {
|
|
require.NoError(t, err)
|
|
testutil.DatumEqual(t, types.NewDatum(tc.expect[0]), val)
|
|
} else {
|
|
require.Error(t, err)
|
|
require.Regexp(t, tc.expect[1], err.Error())
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestArithmeticDivide(t *testing.T) {
|
|
ctx := createContext(t)
|
|
|
|
testCases := []struct {
|
|
args []any
|
|
expect any
|
|
}{
|
|
{
|
|
args: []any{11.1111111, 11.1},
|
|
expect: 1.001001,
|
|
},
|
|
{
|
|
args: []any{11.1111111, float64(0)},
|
|
expect: nil,
|
|
},
|
|
{
|
|
args: []any{int64(11), int64(11)},
|
|
expect: float64(1),
|
|
},
|
|
{
|
|
args: []any{int64(11), int64(2)},
|
|
expect: 5.5,
|
|
},
|
|
{
|
|
args: []any{int64(11), int64(0)},
|
|
expect: nil,
|
|
},
|
|
{
|
|
args: []any{uint64(11), uint64(11)},
|
|
expect: float64(1),
|
|
},
|
|
{
|
|
args: []any{uint64(11), uint64(2)},
|
|
expect: 5.5,
|
|
},
|
|
{
|
|
args: []any{uint64(11), uint64(0)},
|
|
expect: nil,
|
|
},
|
|
{
|
|
args: []any{nil, -0.11101},
|
|
expect: nil,
|
|
},
|
|
{
|
|
args: []any{1.01, nil},
|
|
expect: nil,
|
|
},
|
|
{
|
|
args: []any{nil, nil},
|
|
expect: nil,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
sig, err := funcs[ast.Div].getFunction(ctx, datumsToConstants(types.MakeDatums(tc.args...)))
|
|
require.NoError(t, err)
|
|
require.NotNil(t, sig)
|
|
switch sig.(type) {
|
|
case *builtinArithmeticIntDivideIntSig:
|
|
require.Equal(t, tipb.ScalarFuncSig_IntDivideInt, sig.PbCode())
|
|
case *builtinArithmeticIntDivideDecimalSig:
|
|
require.Equal(t, tipb.ScalarFuncSig_IntDivideDecimal, sig.PbCode())
|
|
}
|
|
val, err := evalBuiltinFunc(sig, ctx, chunk.Row{})
|
|
require.NoError(t, err)
|
|
testutil.DatumEqual(t, types.NewDatum(tc.expect), val)
|
|
}
|
|
}
|
|
|
|
func TestArithmeticIntDivide(t *testing.T) {
|
|
ctx := createContext(t)
|
|
testCases := []struct {
|
|
args []any
|
|
expect []any
|
|
}{
|
|
{
|
|
args: []any{int64(13), int64(11)},
|
|
expect: []any{int64(1), nil},
|
|
},
|
|
{
|
|
args: []any{int64(-13), int64(11)},
|
|
expect: []any{int64(-1), nil},
|
|
},
|
|
{
|
|
args: []any{int64(13), int64(-11)},
|
|
expect: []any{int64(-1), nil},
|
|
},
|
|
{
|
|
args: []any{int64(-13), int64(-11)},
|
|
expect: []any{int64(1), nil},
|
|
},
|
|
{
|
|
args: []any{int64(33), int64(11)},
|
|
expect: []any{int64(3), nil},
|
|
},
|
|
{
|
|
args: []any{int64(-33), int64(11)},
|
|
expect: []any{int64(-3), nil},
|
|
},
|
|
{
|
|
args: []any{int64(33), int64(-11)},
|
|
expect: []any{int64(-3), nil},
|
|
},
|
|
{
|
|
args: []any{int64(-33), int64(-11)},
|
|
expect: []any{int64(3), nil},
|
|
},
|
|
{
|
|
args: []any{int64(11), int64(0)},
|
|
expect: []any{nil, nil},
|
|
},
|
|
{
|
|
args: []any{int64(-11), int64(0)},
|
|
expect: []any{nil, nil},
|
|
},
|
|
{
|
|
args: []any{11.01, 1.1},
|
|
expect: []any{int64(10), nil},
|
|
},
|
|
{
|
|
args: []any{-11.01, 1.1},
|
|
expect: []any{int64(-10), nil},
|
|
},
|
|
{
|
|
args: []any{11.01, -1.1},
|
|
expect: []any{int64(-10), nil},
|
|
},
|
|
{
|
|
args: []any{-11.01, -1.1},
|
|
expect: []any{int64(10), nil},
|
|
},
|
|
{
|
|
args: []any{nil, -0.11101},
|
|
expect: []any{nil, nil},
|
|
},
|
|
{
|
|
args: []any{1.01, nil},
|
|
expect: []any{nil, nil},
|
|
},
|
|
{
|
|
args: []any{nil, int64(-1001)},
|
|
expect: []any{nil, nil},
|
|
},
|
|
{
|
|
args: []any{int64(101), nil},
|
|
expect: []any{nil, nil},
|
|
},
|
|
{
|
|
args: []any{nil, nil},
|
|
expect: []any{nil, nil},
|
|
},
|
|
{
|
|
args: []any{123456789100000.0, -0.00001},
|
|
expect: []any{nil, ".*BIGINT value is out of range in '\\(123456789100000 DIV -0.00001\\)'"},
|
|
},
|
|
{
|
|
args: []any{int64(-9223372036854775808), float64(-1)},
|
|
expect: []any{nil, ".*BIGINT value is out of range in '\\(-9223372036854775808 DIV -1\\)'"},
|
|
},
|
|
{
|
|
args: []any{uint64(1), float64(-2)},
|
|
expect: []any{0, nil},
|
|
},
|
|
{
|
|
args: []any{uint64(1), float64(-1)},
|
|
expect: []any{nil, ".*BIGINT UNSIGNED value is out of range in '\\(1 DIV -1\\)'"},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
sig, err := funcs[ast.IntDiv].getFunction(ctx, datumsToConstants(types.MakeDatums(tc.args...)))
|
|
require.NoError(t, err)
|
|
require.NotNil(t, sig)
|
|
val, err := evalBuiltinFunc(sig, ctx, chunk.Row{})
|
|
if tc.expect[1] == nil {
|
|
require.NoError(t, err)
|
|
testutil.DatumEqual(t, types.NewDatum(tc.expect[0]), val)
|
|
} else {
|
|
require.Error(t, err)
|
|
require.Regexp(t, tc.expect[1], err.Error())
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestArithmeticMod(t *testing.T) {
|
|
ctx := createContext(t)
|
|
testCases := []struct {
|
|
args []any
|
|
expect any
|
|
}{
|
|
{
|
|
args: []any{int64(13), int64(11)},
|
|
expect: int64(2),
|
|
},
|
|
{
|
|
args: []any{int64(13), int64(11)},
|
|
expect: int64(2),
|
|
},
|
|
{
|
|
args: []any{int64(13), int64(0)},
|
|
expect: nil,
|
|
},
|
|
{
|
|
args: []any{uint64(13), int64(0)},
|
|
expect: nil,
|
|
},
|
|
{
|
|
args: []any{int64(13), uint64(0)},
|
|
expect: nil,
|
|
},
|
|
{
|
|
args: []any{uint64(math.MaxInt64 + 1), int64(math.MinInt64)},
|
|
expect: int64(0),
|
|
},
|
|
{
|
|
args: []any{int64(-22), uint64(10)},
|
|
expect: int64(-2),
|
|
},
|
|
{
|
|
args: []any{int64(math.MinInt64), uint64(3)},
|
|
expect: int64(-2),
|
|
},
|
|
{
|
|
args: []any{int64(-13), int64(11)},
|
|
expect: int64(-2),
|
|
},
|
|
{
|
|
args: []any{int64(13), int64(-11)},
|
|
expect: int64(2),
|
|
},
|
|
{
|
|
args: []any{int64(-13), int64(-11)},
|
|
expect: int64(-2),
|
|
},
|
|
{
|
|
args: []any{int64(33), int64(11)},
|
|
expect: int64(0),
|
|
},
|
|
{
|
|
args: []any{int64(-33), int64(11)},
|
|
expect: int64(0),
|
|
},
|
|
{
|
|
args: []any{int64(33), int64(-11)},
|
|
expect: int64(0),
|
|
},
|
|
{
|
|
args: []any{int64(-33), int64(-11)},
|
|
expect: int64(0),
|
|
},
|
|
{
|
|
args: []any{int64(11), int64(0)},
|
|
expect: nil,
|
|
},
|
|
{
|
|
args: []any{int64(-11), int64(0)},
|
|
expect: nil,
|
|
},
|
|
{
|
|
args: []any{int64(1), 1.1},
|
|
expect: float64(1),
|
|
},
|
|
{
|
|
args: []any{int64(-1), 1.1},
|
|
expect: float64(-1),
|
|
},
|
|
{
|
|
args: []any{int64(1), -1.1},
|
|
expect: float64(1),
|
|
},
|
|
{
|
|
args: []any{int64(-1), -1.1},
|
|
expect: float64(-1),
|
|
},
|
|
{
|
|
args: []any{nil, -0.11101},
|
|
expect: nil,
|
|
},
|
|
{
|
|
args: []any{1.01, nil},
|
|
expect: nil,
|
|
},
|
|
{
|
|
args: []any{nil, int64(-1001)},
|
|
expect: nil,
|
|
},
|
|
{
|
|
args: []any{int64(101), nil},
|
|
expect: nil,
|
|
},
|
|
{
|
|
args: []any{nil, nil},
|
|
expect: nil,
|
|
},
|
|
{
|
|
args: []any{"1231", 12},
|
|
expect: 7,
|
|
},
|
|
{
|
|
args: []any{"1231", "12"},
|
|
expect: float64(7),
|
|
},
|
|
{
|
|
args: []any{types.Duration{Duration: 45296 * time.Second}, 122},
|
|
expect: 114,
|
|
},
|
|
{
|
|
args: []any{types.Set{Value: 7, Name: "abc"}, "12"},
|
|
expect: float64(7),
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
sig, err := funcs[ast.Mod].getFunction(ctx, datumsToConstants(types.MakeDatums(tc.args...)))
|
|
require.NoError(t, err)
|
|
require.NotNil(t, sig)
|
|
val, err := evalBuiltinFunc(sig, ctx, chunk.Row{})
|
|
switch sig.(type) {
|
|
case *builtinArithmeticModRealSig:
|
|
require.Equal(t, tipb.ScalarFuncSig_ModReal, sig.PbCode())
|
|
case *builtinArithmeticModIntUnsignedUnsignedSig:
|
|
require.Equal(t, tipb.ScalarFuncSig_ModIntUnsignedUnsigned, sig.PbCode())
|
|
case *builtinArithmeticModIntUnsignedSignedSig:
|
|
require.Equal(t, tipb.ScalarFuncSig_ModIntUnsignedSigned, sig.PbCode())
|
|
case *builtinArithmeticModIntSignedUnsignedSig:
|
|
require.Equal(t, tipb.ScalarFuncSig_ModIntSignedUnsigned, sig.PbCode())
|
|
case *builtinArithmeticModIntSignedSignedSig:
|
|
require.Equal(t, tipb.ScalarFuncSig_ModIntSignedSigned, sig.PbCode())
|
|
case *builtinArithmeticModDecimalSig:
|
|
require.Equal(t, tipb.ScalarFuncSig_ModDecimal, sig.PbCode())
|
|
}
|
|
require.NoError(t, err)
|
|
testutil.DatumEqual(t, types.NewDatum(tc.expect), val)
|
|
}
|
|
}
|
|
|
|
func TestDecimalErrOverflow(t *testing.T) {
|
|
ctx := createContext(t)
|
|
testCases := []struct {
|
|
args []float64
|
|
opd string
|
|
sig tipb.ScalarFuncSig
|
|
errStr string
|
|
}{
|
|
{
|
|
args: []float64{8.1e80, 8.1e80},
|
|
opd: ast.Plus,
|
|
sig: tipb.ScalarFuncSig_PlusDecimal,
|
|
errStr: "[types:1690]DECIMAL value is out of range in '(810000000000000000000000000000000000000000000000000000000000000000000000000000000 + 810000000000000000000000000000000000000000000000000000000000000000000000000000000)'",
|
|
},
|
|
{
|
|
args: []float64{8.1e80, -8.1e80},
|
|
opd: ast.Minus,
|
|
sig: tipb.ScalarFuncSig_MinusDecimal,
|
|
errStr: "[types:1690]DECIMAL value is out of range in '(810000000000000000000000000000000000000000000000000000000000000000000000000000000 - -810000000000000000000000000000000000000000000000000000000000000000000000000000000)'",
|
|
},
|
|
{
|
|
args: []float64{8.1e80, 8.1e80},
|
|
opd: ast.Mul,
|
|
sig: tipb.ScalarFuncSig_MultiplyDecimal,
|
|
errStr: "[types:1690]DECIMAL value is out of range in '(810000000000000000000000000000000000000000000000000000000000000000000000000000000 * 810000000000000000000000000000000000000000000000000000000000000000000000000000000)'",
|
|
},
|
|
{
|
|
args: []float64{8.1e80, 0.1},
|
|
opd: ast.Div,
|
|
sig: tipb.ScalarFuncSig_DivideDecimal,
|
|
errStr: "[types:1690]DECIMAL value is out of range in '(810000000000000000000000000000000000000000000000000000000000000000000000000000000 / 0.1)'",
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
dec1, dec2 := types.NewDecFromFloatForTest(tc.args[0]), types.NewDecFromFloatForTest(tc.args[1])
|
|
bf, err := funcs[tc.opd].getFunction(ctx, datumsToConstants(types.MakeDatums(dec1, dec2)))
|
|
require.NoError(t, err)
|
|
require.NotNil(t, bf)
|
|
require.Equal(t, tc.sig, bf.PbCode())
|
|
_, err = evalBuiltinFunc(bf, ctx, chunk.Row{})
|
|
require.EqualError(t, err, tc.errStr)
|
|
}
|
|
}
|
|
|
|
// TestArithmeticOverflowErrorMessageWithColumnName tests that overflow error messages
|
|
// display the actual column name instead of "Column#N".
|
|
// This is a regression test for https://github.com/pingcap/tidb/issues/17993
|
|
func TestArithmeticOverflowErrorMessageWithColumnName(t *testing.T) {
|
|
// Test case 1: Simple column multiplication (col * constant)
|
|
t.Run("SimpleColumn", func(t *testing.T) {
|
|
ctx := createContext(t)
|
|
|
|
// Create a column with OrigName set (simulating a real table column)
|
|
col := &Column{
|
|
RetType: types.NewFieldType(mysql.TypeLonglong),
|
|
ID: 1,
|
|
UniqueID: 1,
|
|
Index: 0,
|
|
OrigName: "test.t.col1",
|
|
}
|
|
|
|
// Create a constant for -1
|
|
constant := &Constant{
|
|
Value: types.NewIntDatum(-1),
|
|
RetType: types.NewFieldType(mysql.TypeLonglong),
|
|
}
|
|
|
|
// Build the multiply function with column * constant
|
|
// When column value is MinInt64 and multiplied by -1, it overflows
|
|
bf, err := funcs[ast.Mul].getFunction(ctx, []Expression{col, constant})
|
|
require.NoError(t, err)
|
|
require.NotNil(t, bf)
|
|
|
|
// Create a mock chunk with MinInt64 that will cause overflow when multiplied by -1
|
|
chk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}, 1)
|
|
chk.AppendInt64(0, math.MinInt64)
|
|
row := chk.GetRow(0)
|
|
|
|
// Execute the function - should cause overflow
|
|
_, _, err = bf.evalInt(ctx, row)
|
|
require.Error(t, err)
|
|
|
|
// The error message should contain the actual column name "test.t.col1"
|
|
// instead of "Column#1"
|
|
errMsg := err.Error()
|
|
require.Contains(t, errMsg, "test.t.col1", "Error message should contain the actual column name")
|
|
require.NotContains(t, errMsg, "Column#", "Error message should not contain 'Column#'")
|
|
})
|
|
|
|
// Test case 2: Derived column - (col1 + col2) * constant
|
|
// This tests the case: (t5.col7 + t5.col2) * ABS(-9223372036854775807)
|
|
t.Run("DerivedColumn", func(t *testing.T) {
|
|
ctx := createContext(t)
|
|
|
|
// Create two columns with OrigName set (simulating real table columns)
|
|
col1 := &Column{
|
|
RetType: types.NewFieldType(mysql.TypeLonglong),
|
|
ID: 1,
|
|
UniqueID: 1,
|
|
Index: 0,
|
|
OrigName: "t5.col7",
|
|
}
|
|
col2 := &Column{
|
|
RetType: types.NewFieldType(mysql.TypeLonglong),
|
|
ID: 2,
|
|
UniqueID: 2,
|
|
Index: 1,
|
|
OrigName: "t5.col2",
|
|
}
|
|
|
|
// Build the addition function (col1 + col2) - this creates a derived column
|
|
addFunc, err := funcs[ast.Plus].getFunction(ctx, []Expression{col1, col2})
|
|
require.NoError(t, err)
|
|
require.NotNil(t, addFunc)
|
|
|
|
// Wrap the addition function as a ScalarFunction
|
|
addExpr := &ScalarFunction{
|
|
FuncName: ast.NewCIStr(ast.Plus),
|
|
RetType: types.NewFieldType(mysql.TypeLonglong),
|
|
Function: addFunc,
|
|
}
|
|
|
|
// Create a constant for a large value that will cause overflow when multiplied
|
|
constant := &Constant{
|
|
Value: types.NewIntDatum(math.MaxInt64),
|
|
RetType: types.NewFieldType(mysql.TypeLonglong),
|
|
}
|
|
|
|
// Build the multiply function: (col1 + col2) * constant
|
|
mulFunc, err := funcs[ast.Mul].getFunction(ctx, []Expression{addExpr, constant})
|
|
require.NoError(t, err)
|
|
require.NotNil(t, mulFunc)
|
|
|
|
// Create a mock chunk with values that will cause overflow
|
|
// col1 = 2, col2 = 1, so (col1 + col2) = 3, and 3 * MaxInt64 overflows
|
|
chk := chunk.NewChunkWithCapacity([]*types.FieldType{
|
|
types.NewFieldType(mysql.TypeLonglong),
|
|
types.NewFieldType(mysql.TypeLonglong),
|
|
}, 1)
|
|
chk.AppendInt64(0, 2)
|
|
chk.AppendInt64(1, 1)
|
|
row := chk.GetRow(0)
|
|
|
|
// Execute the function - should cause overflow
|
|
_, _, err = mulFunc.evalInt(ctx, row)
|
|
require.Error(t, err)
|
|
|
|
// The error message should contain the actual column names "t5.col7" and "t5.col2"
|
|
// instead of "Column#1" and "Column#2"
|
|
errMsg := err.Error()
|
|
require.Contains(t, errMsg, "t5.col7", "Error message should contain the actual column name t5.col7")
|
|
require.Contains(t, errMsg, "t5.col2", "Error message should contain the actual column name t5.col2")
|
|
require.NotContains(t, errMsg, "Column#", "Error message should not contain 'Column#'")
|
|
})
|
|
}
|