Files
tidb/pkg/expression/builtin_arithmetic_test.go

817 lines
22 KiB
Go

// Copyright 2017 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package expression
import (
"math"
"testing"
"time"
"github.com/pingcap/tidb/pkg/expression/exprstatic"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/testkit/testutil"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tipb/go-tipb"
"github.com/stretchr/testify/require"
)
func TestSetFlenDecimal4RealOrDecimal(t *testing.T) {
ctx := exprstatic.NewEvalContext()
ret := &types.FieldType{}
a := &types.FieldType{}
a.SetDecimal(1)
a.SetFlen(3)
b := &types.FieldType{}
b.SetDecimal(0)
b.SetFlag(2)
setFlenDecimal4RealOrDecimal(ctx, ret, &Constant{RetType: a}, &Constant{RetType: b}, true, false)
require.Equal(t, 1, ret.GetDecimal())
require.Equal(t, 4, ret.GetFlen())
b.SetFlen(65)
setFlenDecimal4RealOrDecimal(ctx, ret, &Constant{RetType: a}, &Constant{RetType: b}, true, false)
require.Equal(t, 1, ret.GetDecimal())
require.Equal(t, mysql.MaxRealWidth, ret.GetFlen())
setFlenDecimal4RealOrDecimal(ctx, ret, &Constant{RetType: a}, &Constant{RetType: b}, false, false)
require.Equal(t, 1, ret.GetDecimal())
require.Equal(t, mysql.MaxDecimalWidth, ret.GetFlen())
b.SetFlen(types.UnspecifiedLength)
setFlenDecimal4RealOrDecimal(ctx, ret, &Constant{RetType: a}, &Constant{RetType: b}, true, false)
require.Equal(t, 1, ret.GetDecimal())
require.Equal(t, types.UnspecifiedLength, ret.GetFlen())
b.SetDecimal(types.UnspecifiedLength)
setFlenDecimal4RealOrDecimal(ctx, ret, &Constant{RetType: a}, &Constant{RetType: b}, true, false)
require.Equal(t, types.UnspecifiedLength, ret.GetDecimal())
require.Equal(t, types.UnspecifiedLength, ret.GetFlen())
ret = &types.FieldType{}
a = &types.FieldType{}
a.SetDecimal(1)
a.SetFlen(3)
b = &types.FieldType{}
b.SetDecimal(0)
b.SetFlen(2)
setFlenDecimal4RealOrDecimal(ctx, ret, &Constant{RetType: a}, &Constant{RetType: b}, true, true)
require.Equal(t, 1, ret.GetDecimal())
require.Equal(t, 5, ret.GetFlen())
b.SetFlen(65)
setFlenDecimal4RealOrDecimal(ctx, ret, &Constant{RetType: a}, &Constant{RetType: b}, true, true)
require.Equal(t, 1, ret.GetDecimal())
require.Equal(t, mysql.MaxRealWidth, ret.GetFlen())
setFlenDecimal4RealOrDecimal(ctx, ret, &Constant{RetType: a}, &Constant{RetType: b}, false, true)
require.Equal(t, 1, ret.GetDecimal())
require.Equal(t, mysql.MaxDecimalWidth, ret.GetFlen())
b.SetFlen(types.UnspecifiedLength)
setFlenDecimal4RealOrDecimal(ctx, ret, &Constant{RetType: a}, &Constant{RetType: b}, true, true)
require.Equal(t, 1, ret.GetDecimal())
require.Equal(t, types.UnspecifiedLength, ret.GetFlen())
b.SetDecimal(types.UnspecifiedLength)
setFlenDecimal4RealOrDecimal(ctx, ret, &Constant{RetType: a}, &Constant{RetType: b}, true, true)
require.Equal(t, types.UnspecifiedLength, ret.GetDecimal())
require.Equal(t, types.UnspecifiedLength, ret.GetFlen())
}
func TestArithmeticPlus(t *testing.T) {
ctx := createContext(t)
// case: 1
args := []any{int64(12), int64(1)}
bf, err := funcs[ast.Plus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
require.NoError(t, err)
require.NotNil(t, bf)
intSig, ok := bf.(*builtinArithmeticPlusIntSig)
require.True(t, ok)
require.NotNil(t, intSig)
intResult, isNull, err := intSig.evalInt(ctx, chunk.Row{})
require.NoError(t, err)
require.False(t, isNull)
require.Equal(t, int64(13), intResult)
// case 2
args = []any{1.01001, -0.01}
bf, err = funcs[ast.Plus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
require.NoError(t, err)
require.NotNil(t, bf)
realSig, ok := bf.(*builtinArithmeticPlusRealSig)
require.True(t, ok)
require.NotNil(t, realSig)
realResult, isNull, err := realSig.evalReal(ctx, chunk.Row{})
require.NoError(t, err)
require.False(t, isNull)
require.Equal(t, 1.00001, realResult)
// case 3
args = []any{nil, -0.11101}
bf, err = funcs[ast.Plus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
require.NoError(t, err)
require.NotNil(t, bf)
realSig, ok = bf.(*builtinArithmeticPlusRealSig)
require.True(t, ok)
require.NotNil(t, realSig)
realResult, isNull, err = realSig.evalReal(ctx, chunk.Row{})
require.NoError(t, err)
require.True(t, isNull)
require.Equal(t, float64(0), realResult)
// case 4
args = []any{nil, nil}
bf, err = funcs[ast.Plus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
require.NoError(t, err)
require.NotNil(t, bf)
realSig, ok = bf.(*builtinArithmeticPlusRealSig)
require.True(t, ok)
require.NotNil(t, realSig)
realResult, isNull, err = realSig.evalReal(ctx, chunk.Row{})
require.NoError(t, err)
require.True(t, isNull)
require.Equal(t, float64(0), realResult)
// case 5
hexStr, err := types.ParseHexStr("0x20000000000000")
require.NoError(t, err)
args = []any{hexStr, int64(1)}
bf, err = funcs[ast.Plus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
require.NoError(t, err)
require.NotNil(t, bf)
intSig, ok = bf.(*builtinArithmeticPlusIntSig)
require.True(t, ok)
require.NotNil(t, intSig)
intResult, _, err = intSig.evalInt(ctx, chunk.Row{})
require.NoError(t, err)
require.Equal(t, int64(9007199254740993), intResult)
bitStr, err := types.NewBitLiteral("0b00011")
require.NoError(t, err)
args = []any{bitStr, int64(1)}
bf, err = funcs[ast.Plus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
require.NoError(t, err)
require.NotNil(t, bf)
//check the result type is int
intSig, ok = bf.(*builtinArithmeticPlusIntSig)
require.True(t, ok)
require.NotNil(t, intSig)
intResult, _, err = intSig.evalInt(ctx, chunk.Row{})
require.NoError(t, err)
require.Equal(t, int64(4), intResult)
}
func TestArithmeticMinus(t *testing.T) {
ctx := createContext(t)
// case: 1
args := []any{int64(12), int64(1)}
bf, err := funcs[ast.Minus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
require.NoError(t, err)
require.NotNil(t, bf)
intSig, ok := bf.(*builtinArithmeticMinusIntSig)
require.True(t, ok)
require.NotNil(t, intSig)
intResult, isNull, err := intSig.evalInt(ctx, chunk.Row{})
require.NoError(t, err)
require.False(t, isNull)
require.Equal(t, int64(11), intResult)
// case 2
args = []any{1.01001, -0.01}
bf, err = funcs[ast.Minus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
require.NoError(t, err)
require.NotNil(t, bf)
realSig, ok := bf.(*builtinArithmeticMinusRealSig)
require.True(t, ok)
require.NotNil(t, realSig)
realResult, isNull, err := realSig.evalReal(ctx, chunk.Row{})
require.NoError(t, err)
require.False(t, isNull)
require.Equal(t, 1.02001, realResult)
// case 3
args = []any{nil, -0.11101}
bf, err = funcs[ast.Minus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
require.NoError(t, err)
require.NotNil(t, bf)
realSig, ok = bf.(*builtinArithmeticMinusRealSig)
require.True(t, ok)
require.NotNil(t, realSig)
realResult, isNull, err = realSig.evalReal(ctx, chunk.Row{})
require.NoError(t, err)
require.True(t, isNull)
require.Equal(t, float64(0), realResult)
// case 4
args = []any{1.01, nil}
bf, err = funcs[ast.Minus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
require.NoError(t, err)
require.NotNil(t, bf)
realSig, ok = bf.(*builtinArithmeticMinusRealSig)
require.True(t, ok)
require.NotNil(t, realSig)
realResult, isNull, err = realSig.evalReal(ctx, chunk.Row{})
require.NoError(t, err)
require.True(t, isNull)
require.Equal(t, float64(0), realResult)
// case 5
args = []any{nil, nil}
bf, err = funcs[ast.Minus].getFunction(ctx, datumsToConstants(types.MakeDatums(args...)))
require.NoError(t, err)
require.NotNil(t, bf)
realSig, ok = bf.(*builtinArithmeticMinusRealSig)
require.True(t, ok)
require.NotNil(t, realSig)
realResult, isNull, err = realSig.evalReal(ctx, chunk.Row{})
require.NoError(t, err)
require.True(t, isNull)
require.Equal(t, float64(0), realResult)
}
func TestArithmeticMultiply(t *testing.T) {
ctx := createContext(t)
testCases := []struct {
args []any
expect []any
err error
}{
{
args: []any{int64(11), int64(11)},
expect: []any{int64(121), nil},
},
{
args: []any{int64(-1), int64(math.MinInt64)},
expect: []any{nil, "BIGINT value is out of range in '\\(-1 \\* -9223372036854775808\\)'$"},
},
{
args: []any{int64(math.MinInt64), int64(-1)},
expect: []any{nil, "BIGINT value is out of range in '\\(-9223372036854775808 \\* -1\\)'$"},
},
{
args: []any{uint64(11), uint64(11)},
expect: []any{int64(121), nil},
},
{
args: []any{float64(11), float64(11)},
expect: []any{float64(121), nil},
},
{
args: []any{nil, -0.11101},
expect: []any{nil, nil},
},
{
args: []any{1.01, nil},
expect: []any{nil, nil},
},
{
args: []any{nil, nil},
expect: []any{nil, nil},
},
}
for _, tc := range testCases {
sig, err := funcs[ast.Mul].getFunction(ctx, datumsToConstants(types.MakeDatums(tc.args...)))
require.NoError(t, err)
require.NotNil(t, sig)
val, err := evalBuiltinFunc(sig, ctx, chunk.Row{})
if tc.expect[1] == nil {
require.NoError(t, err)
testutil.DatumEqual(t, types.NewDatum(tc.expect[0]), val)
} else {
require.Error(t, err)
require.Regexp(t, tc.expect[1], err.Error())
}
}
}
func TestArithmeticDivide(t *testing.T) {
ctx := createContext(t)
testCases := []struct {
args []any
expect any
}{
{
args: []any{11.1111111, 11.1},
expect: 1.001001,
},
{
args: []any{11.1111111, float64(0)},
expect: nil,
},
{
args: []any{int64(11), int64(11)},
expect: float64(1),
},
{
args: []any{int64(11), int64(2)},
expect: 5.5,
},
{
args: []any{int64(11), int64(0)},
expect: nil,
},
{
args: []any{uint64(11), uint64(11)},
expect: float64(1),
},
{
args: []any{uint64(11), uint64(2)},
expect: 5.5,
},
{
args: []any{uint64(11), uint64(0)},
expect: nil,
},
{
args: []any{nil, -0.11101},
expect: nil,
},
{
args: []any{1.01, nil},
expect: nil,
},
{
args: []any{nil, nil},
expect: nil,
},
}
for _, tc := range testCases {
sig, err := funcs[ast.Div].getFunction(ctx, datumsToConstants(types.MakeDatums(tc.args...)))
require.NoError(t, err)
require.NotNil(t, sig)
switch sig.(type) {
case *builtinArithmeticIntDivideIntSig:
require.Equal(t, tipb.ScalarFuncSig_IntDivideInt, sig.PbCode())
case *builtinArithmeticIntDivideDecimalSig:
require.Equal(t, tipb.ScalarFuncSig_IntDivideDecimal, sig.PbCode())
}
val, err := evalBuiltinFunc(sig, ctx, chunk.Row{})
require.NoError(t, err)
testutil.DatumEqual(t, types.NewDatum(tc.expect), val)
}
}
func TestArithmeticIntDivide(t *testing.T) {
ctx := createContext(t)
testCases := []struct {
args []any
expect []any
}{
{
args: []any{int64(13), int64(11)},
expect: []any{int64(1), nil},
},
{
args: []any{int64(-13), int64(11)},
expect: []any{int64(-1), nil},
},
{
args: []any{int64(13), int64(-11)},
expect: []any{int64(-1), nil},
},
{
args: []any{int64(-13), int64(-11)},
expect: []any{int64(1), nil},
},
{
args: []any{int64(33), int64(11)},
expect: []any{int64(3), nil},
},
{
args: []any{int64(-33), int64(11)},
expect: []any{int64(-3), nil},
},
{
args: []any{int64(33), int64(-11)},
expect: []any{int64(-3), nil},
},
{
args: []any{int64(-33), int64(-11)},
expect: []any{int64(3), nil},
},
{
args: []any{int64(11), int64(0)},
expect: []any{nil, nil},
},
{
args: []any{int64(-11), int64(0)},
expect: []any{nil, nil},
},
{
args: []any{11.01, 1.1},
expect: []any{int64(10), nil},
},
{
args: []any{-11.01, 1.1},
expect: []any{int64(-10), nil},
},
{
args: []any{11.01, -1.1},
expect: []any{int64(-10), nil},
},
{
args: []any{-11.01, -1.1},
expect: []any{int64(10), nil},
},
{
args: []any{nil, -0.11101},
expect: []any{nil, nil},
},
{
args: []any{1.01, nil},
expect: []any{nil, nil},
},
{
args: []any{nil, int64(-1001)},
expect: []any{nil, nil},
},
{
args: []any{int64(101), nil},
expect: []any{nil, nil},
},
{
args: []any{nil, nil},
expect: []any{nil, nil},
},
{
args: []any{123456789100000.0, -0.00001},
expect: []any{nil, ".*BIGINT value is out of range in '\\(123456789100000 DIV -0.00001\\)'"},
},
{
args: []any{int64(-9223372036854775808), float64(-1)},
expect: []any{nil, ".*BIGINT value is out of range in '\\(-9223372036854775808 DIV -1\\)'"},
},
{
args: []any{uint64(1), float64(-2)},
expect: []any{0, nil},
},
{
args: []any{uint64(1), float64(-1)},
expect: []any{nil, ".*BIGINT UNSIGNED value is out of range in '\\(1 DIV -1\\)'"},
},
}
for _, tc := range testCases {
sig, err := funcs[ast.IntDiv].getFunction(ctx, datumsToConstants(types.MakeDatums(tc.args...)))
require.NoError(t, err)
require.NotNil(t, sig)
val, err := evalBuiltinFunc(sig, ctx, chunk.Row{})
if tc.expect[1] == nil {
require.NoError(t, err)
testutil.DatumEqual(t, types.NewDatum(tc.expect[0]), val)
} else {
require.Error(t, err)
require.Regexp(t, tc.expect[1], err.Error())
}
}
}
func TestArithmeticMod(t *testing.T) {
ctx := createContext(t)
testCases := []struct {
args []any
expect any
}{
{
args: []any{int64(13), int64(11)},
expect: int64(2),
},
{
args: []any{int64(13), int64(11)},
expect: int64(2),
},
{
args: []any{int64(13), int64(0)},
expect: nil,
},
{
args: []any{uint64(13), int64(0)},
expect: nil,
},
{
args: []any{int64(13), uint64(0)},
expect: nil,
},
{
args: []any{uint64(math.MaxInt64 + 1), int64(math.MinInt64)},
expect: int64(0),
},
{
args: []any{int64(-22), uint64(10)},
expect: int64(-2),
},
{
args: []any{int64(math.MinInt64), uint64(3)},
expect: int64(-2),
},
{
args: []any{int64(-13), int64(11)},
expect: int64(-2),
},
{
args: []any{int64(13), int64(-11)},
expect: int64(2),
},
{
args: []any{int64(-13), int64(-11)},
expect: int64(-2),
},
{
args: []any{int64(33), int64(11)},
expect: int64(0),
},
{
args: []any{int64(-33), int64(11)},
expect: int64(0),
},
{
args: []any{int64(33), int64(-11)},
expect: int64(0),
},
{
args: []any{int64(-33), int64(-11)},
expect: int64(0),
},
{
args: []any{int64(11), int64(0)},
expect: nil,
},
{
args: []any{int64(-11), int64(0)},
expect: nil,
},
{
args: []any{int64(1), 1.1},
expect: float64(1),
},
{
args: []any{int64(-1), 1.1},
expect: float64(-1),
},
{
args: []any{int64(1), -1.1},
expect: float64(1),
},
{
args: []any{int64(-1), -1.1},
expect: float64(-1),
},
{
args: []any{nil, -0.11101},
expect: nil,
},
{
args: []any{1.01, nil},
expect: nil,
},
{
args: []any{nil, int64(-1001)},
expect: nil,
},
{
args: []any{int64(101), nil},
expect: nil,
},
{
args: []any{nil, nil},
expect: nil,
},
{
args: []any{"1231", 12},
expect: 7,
},
{
args: []any{"1231", "12"},
expect: float64(7),
},
{
args: []any{types.Duration{Duration: 45296 * time.Second}, 122},
expect: 114,
},
{
args: []any{types.Set{Value: 7, Name: "abc"}, "12"},
expect: float64(7),
},
}
for _, tc := range testCases {
sig, err := funcs[ast.Mod].getFunction(ctx, datumsToConstants(types.MakeDatums(tc.args...)))
require.NoError(t, err)
require.NotNil(t, sig)
val, err := evalBuiltinFunc(sig, ctx, chunk.Row{})
switch sig.(type) {
case *builtinArithmeticModRealSig:
require.Equal(t, tipb.ScalarFuncSig_ModReal, sig.PbCode())
case *builtinArithmeticModIntUnsignedUnsignedSig:
require.Equal(t, tipb.ScalarFuncSig_ModIntUnsignedUnsigned, sig.PbCode())
case *builtinArithmeticModIntUnsignedSignedSig:
require.Equal(t, tipb.ScalarFuncSig_ModIntUnsignedSigned, sig.PbCode())
case *builtinArithmeticModIntSignedUnsignedSig:
require.Equal(t, tipb.ScalarFuncSig_ModIntSignedUnsigned, sig.PbCode())
case *builtinArithmeticModIntSignedSignedSig:
require.Equal(t, tipb.ScalarFuncSig_ModIntSignedSigned, sig.PbCode())
case *builtinArithmeticModDecimalSig:
require.Equal(t, tipb.ScalarFuncSig_ModDecimal, sig.PbCode())
}
require.NoError(t, err)
testutil.DatumEqual(t, types.NewDatum(tc.expect), val)
}
}
func TestDecimalErrOverflow(t *testing.T) {
ctx := createContext(t)
testCases := []struct {
args []float64
opd string
sig tipb.ScalarFuncSig
errStr string
}{
{
args: []float64{8.1e80, 8.1e80},
opd: ast.Plus,
sig: tipb.ScalarFuncSig_PlusDecimal,
errStr: "[types:1690]DECIMAL value is out of range in '(810000000000000000000000000000000000000000000000000000000000000000000000000000000 + 810000000000000000000000000000000000000000000000000000000000000000000000000000000)'",
},
{
args: []float64{8.1e80, -8.1e80},
opd: ast.Minus,
sig: tipb.ScalarFuncSig_MinusDecimal,
errStr: "[types:1690]DECIMAL value is out of range in '(810000000000000000000000000000000000000000000000000000000000000000000000000000000 - -810000000000000000000000000000000000000000000000000000000000000000000000000000000)'",
},
{
args: []float64{8.1e80, 8.1e80},
opd: ast.Mul,
sig: tipb.ScalarFuncSig_MultiplyDecimal,
errStr: "[types:1690]DECIMAL value is out of range in '(810000000000000000000000000000000000000000000000000000000000000000000000000000000 * 810000000000000000000000000000000000000000000000000000000000000000000000000000000)'",
},
{
args: []float64{8.1e80, 0.1},
opd: ast.Div,
sig: tipb.ScalarFuncSig_DivideDecimal,
errStr: "[types:1690]DECIMAL value is out of range in '(810000000000000000000000000000000000000000000000000000000000000000000000000000000 / 0.1)'",
},
}
for _, tc := range testCases {
dec1, dec2 := types.NewDecFromFloatForTest(tc.args[0]), types.NewDecFromFloatForTest(tc.args[1])
bf, err := funcs[tc.opd].getFunction(ctx, datumsToConstants(types.MakeDatums(dec1, dec2)))
require.NoError(t, err)
require.NotNil(t, bf)
require.Equal(t, tc.sig, bf.PbCode())
_, err = evalBuiltinFunc(bf, ctx, chunk.Row{})
require.EqualError(t, err, tc.errStr)
}
}
// TestArithmeticOverflowErrorMessageWithColumnName tests that overflow error messages
// display the actual column name instead of "Column#N".
// This is a regression test for https://github.com/pingcap/tidb/issues/17993
func TestArithmeticOverflowErrorMessageWithColumnName(t *testing.T) {
// Test case 1: Simple column multiplication (col * constant)
t.Run("SimpleColumn", func(t *testing.T) {
ctx := createContext(t)
// Create a column with OrigName set (simulating a real table column)
col := &Column{
RetType: types.NewFieldType(mysql.TypeLonglong),
ID: 1,
UniqueID: 1,
Index: 0,
OrigName: "test.t.col1",
}
// Create a constant for -1
constant := &Constant{
Value: types.NewIntDatum(-1),
RetType: types.NewFieldType(mysql.TypeLonglong),
}
// Build the multiply function with column * constant
// When column value is MinInt64 and multiplied by -1, it overflows
bf, err := funcs[ast.Mul].getFunction(ctx, []Expression{col, constant})
require.NoError(t, err)
require.NotNil(t, bf)
// Create a mock chunk with MinInt64 that will cause overflow when multiplied by -1
chk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}, 1)
chk.AppendInt64(0, math.MinInt64)
row := chk.GetRow(0)
// Execute the function - should cause overflow
_, _, err = bf.evalInt(ctx, row)
require.Error(t, err)
// The error message should contain the actual column name "test.t.col1"
// instead of "Column#1"
errMsg := err.Error()
require.Contains(t, errMsg, "test.t.col1", "Error message should contain the actual column name")
require.NotContains(t, errMsg, "Column#", "Error message should not contain 'Column#'")
})
// Test case 2: Derived column - (col1 + col2) * constant
// This tests the case: (t5.col7 + t5.col2) * ABS(-9223372036854775807)
t.Run("DerivedColumn", func(t *testing.T) {
ctx := createContext(t)
// Create two columns with OrigName set (simulating real table columns)
col1 := &Column{
RetType: types.NewFieldType(mysql.TypeLonglong),
ID: 1,
UniqueID: 1,
Index: 0,
OrigName: "t5.col7",
}
col2 := &Column{
RetType: types.NewFieldType(mysql.TypeLonglong),
ID: 2,
UniqueID: 2,
Index: 1,
OrigName: "t5.col2",
}
// Build the addition function (col1 + col2) - this creates a derived column
addFunc, err := funcs[ast.Plus].getFunction(ctx, []Expression{col1, col2})
require.NoError(t, err)
require.NotNil(t, addFunc)
// Wrap the addition function as a ScalarFunction
addExpr := &ScalarFunction{
FuncName: ast.NewCIStr(ast.Plus),
RetType: types.NewFieldType(mysql.TypeLonglong),
Function: addFunc,
}
// Create a constant for a large value that will cause overflow when multiplied
constant := &Constant{
Value: types.NewIntDatum(math.MaxInt64),
RetType: types.NewFieldType(mysql.TypeLonglong),
}
// Build the multiply function: (col1 + col2) * constant
mulFunc, err := funcs[ast.Mul].getFunction(ctx, []Expression{addExpr, constant})
require.NoError(t, err)
require.NotNil(t, mulFunc)
// Create a mock chunk with values that will cause overflow
// col1 = 2, col2 = 1, so (col1 + col2) = 3, and 3 * MaxInt64 overflows
chk := chunk.NewChunkWithCapacity([]*types.FieldType{
types.NewFieldType(mysql.TypeLonglong),
types.NewFieldType(mysql.TypeLonglong),
}, 1)
chk.AppendInt64(0, 2)
chk.AppendInt64(1, 1)
row := chk.GetRow(0)
// Execute the function - should cause overflow
_, _, err = mulFunc.evalInt(ctx, row)
require.Error(t, err)
// The error message should contain the actual column names "t5.col7" and "t5.col2"
// instead of "Column#1" and "Column#2"
errMsg := err.Error()
require.Contains(t, errMsg, "t5.col7", "Error message should contain the actual column name t5.col7")
require.Contains(t, errMsg, "t5.col2", "Error message should contain the actual column name t5.col2")
require.NotContains(t, errMsg, "Column#", "Error message should not contain 'Column#'")
})
}