// Copyright 2017 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package expression import ( "testing" "time" "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/stretchr/testify/require" ) func TestCompareFunctionWithRefine(t *testing.T) { ctx := createContext(t) tblInfo := newTestTableBuilder("").add("a", mysql.TypeLong, mysql.NotNullFlag).build() tests := []struct { exprStr string result string }{ {"a < '1.0'", "lt(a, 1)"}, {"a <= '1.0'", "le(a, 1)"}, {"a > '1'", "gt(a, 1)"}, {"a >= '1'", "ge(a, 1)"}, {"a = '1'", "eq(a, 1)"}, {"a <=> '1'", "nulleq(a, 1)"}, {"a != '1'", "ne(a, 1)"}, {"a < '1.1'", "lt(a, 2)"}, {"a <= '1.1'", "le(a, 1)"}, {"a > 1.1", "gt(a, 1)"}, {"a >= '1.1'", "ge(a, 2)"}, {"a = '1.1'", "0"}, {"a <=> '1.1'", "0"}, {"a != '1.1'", "ne(cast(a, double BINARY), 1.1)"}, {"'1' < a", "lt(1, a)"}, {"'1' <= a", "le(1, a)"}, {"'1' > a", "gt(1, a)"}, {"'1' >= a", "ge(1, a)"}, {"'1' = a", "eq(1, a)"}, {"'1' <=> a", "nulleq(1, a)"}, {"'1' != a", "ne(1, a)"}, {"'1.1' < a", "lt(1, a)"}, {"'1.1' <= a", "le(2, a)"}, {"'1.1' > a", "gt(2, a)"}, {"'1.1' >= a", "ge(1, a)"}, {"'1.1' = a", "0"}, {"'1.1' <=> a", "0"}, {"'1.1' != a", "ne(1.1, cast(a, double BINARY))"}, {"'123456789123456711111189' = a", "0"}, {"123456789123456789.12345 = a", "0"}, {"123456789123456789123456789.12345 > a", "1"}, {"-123456789123456789123456789.12345 > a", "0"}, {"123456789123456789123456789.12345 < a", "0"}, {"-123456789123456789123456789.12345 < a", "1"}, {"'aaaa'=a", "eq(0, a)"}, } for _, test := range tests { f, err := ParseSimpleExpr(ctx, test.exprStr, WithTableInfo("", tblInfo)) require.NoError(t, err) require.Equal(t, test.result, f.StringWithCtx(ctx, errors.RedactLogDisable)) } } func TestCompare(t *testing.T) { ctx := createContext(t) intVal, uintVal, realVal, stringVal, decimalVal := 1, uint64(1), 1.1, "123", types.NewDecFromFloatForTest(123.123) timeVal := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeDatetime, 6) durationVal := types.Duration{Duration: 12*time.Hour + 1*time.Minute + 1*time.Second} jsonVal := types.CreateBinaryJSON("123") // test cases for generating function signatures. tests := []struct { arg0 any arg1 any funcName string tp byte expected int64 }{ {intVal, intVal, ast.LT, mysql.TypeLonglong, 0}, {stringVal, stringVal, ast.LT, mysql.TypeVarString, 0}, {intVal, decimalVal, ast.LT, mysql.TypeNewDecimal, 1}, {realVal, decimalVal, ast.LT, mysql.TypeDouble, 1}, {durationVal, durationVal, ast.LT, mysql.TypeDuration, 0}, {realVal, realVal, ast.LT, mysql.TypeDouble, 0}, {intVal, intVal, ast.NullEQ, mysql.TypeLonglong, 1}, {decimalVal, decimalVal, ast.LE, mysql.TypeNewDecimal, 1}, {decimalVal, decimalVal, ast.GT, mysql.TypeNewDecimal, 0}, {decimalVal, decimalVal, ast.GE, mysql.TypeNewDecimal, 1}, {decimalVal, decimalVal, ast.NE, mysql.TypeNewDecimal, 0}, {decimalVal, decimalVal, ast.EQ, mysql.TypeNewDecimal, 1}, {decimalVal, decimalVal, ast.NullEQ, mysql.TypeNewDecimal, 1}, {durationVal, durationVal, ast.LE, mysql.TypeDuration, 1}, {durationVal, durationVal, ast.GT, mysql.TypeDuration, 0}, {durationVal, durationVal, ast.GE, mysql.TypeDuration, 1}, {durationVal, durationVal, ast.EQ, mysql.TypeDuration, 1}, {durationVal, durationVal, ast.NE, mysql.TypeDuration, 0}, {durationVal, durationVal, ast.NullEQ, mysql.TypeDuration, 1}, {nil, nil, ast.NullEQ, mysql.TypeNull, 1}, {nil, intVal, ast.NullEQ, mysql.TypeDouble, 0}, {uintVal, intVal, ast.NullEQ, mysql.TypeLonglong, 1}, {uintVal, intVal, ast.EQ, mysql.TypeLonglong, 1}, {intVal, uintVal, ast.NullEQ, mysql.TypeLonglong, 1}, {intVal, uintVal, ast.EQ, mysql.TypeLonglong, 1}, {timeVal, timeVal, ast.LT, mysql.TypeDatetime, 0}, {timeVal, timeVal, ast.LE, mysql.TypeDatetime, 1}, {timeVal, timeVal, ast.GT, mysql.TypeDatetime, 0}, {timeVal, timeVal, ast.GE, mysql.TypeDatetime, 1}, {timeVal, timeVal, ast.EQ, mysql.TypeDatetime, 1}, {timeVal, timeVal, ast.NE, mysql.TypeDatetime, 0}, {timeVal, timeVal, ast.NullEQ, mysql.TypeDatetime, 1}, {jsonVal, jsonVal, ast.LT, mysql.TypeJSON, 0}, {jsonVal, jsonVal, ast.LE, mysql.TypeJSON, 1}, {jsonVal, jsonVal, ast.GT, mysql.TypeJSON, 0}, {jsonVal, jsonVal, ast.GE, mysql.TypeJSON, 1}, {jsonVal, jsonVal, ast.NE, mysql.TypeJSON, 0}, {jsonVal, jsonVal, ast.EQ, mysql.TypeJSON, 1}, {jsonVal, jsonVal, ast.NullEQ, mysql.TypeJSON, 1}, } for _, test := range tests { bf, err := funcs[test.funcName].getFunction(ctx, primitiveValsToConstants(ctx, []any{test.arg0, test.arg1})) require.NoError(t, err) args := bf.getArgs() require.Equal(t, test.tp, args[0].GetType(ctx).GetType()) require.Equal(t, test.tp, args[1].GetType(ctx).GetType()) res, err := evalBuiltinFunc(bf, ctx, chunk.Row{}) require.NoError(t, err) require.False(t, res.IsNull()) require.Equal(t, types.KindInt64, res.Kind()) require.Equal(t, test.expected, res.GetInt64()) } // test decimalCol, stringCon := &Column{RetType: types.NewFieldType(mysql.TypeNewDecimal)}, &Constant{RetType: types.NewFieldType(mysql.TypeVarchar)} bf, err := funcs[ast.LT].getFunction(ctx, []Expression{decimalCol, stringCon}) require.NoError(t, err) args := bf.getArgs() require.Equal(t, mysql.TypeNewDecimal, args[0].GetType(ctx).GetType()) require.Equal(t, mysql.TypeNewDecimal, args[1].GetType(ctx).GetType()) // test