// Copyright 2017 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package expression import ( "testing" "time" "github.com/pingcap/errors" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" "github.com/pingcap/tidb/util/chunk" "github.com/stretchr/testify/require" ) func TestCompareFunctionWithRefine(t *testing.T) { ctx := createContext(t) tblInfo := newTestTableBuilder("").add("a", mysql.TypeLong, mysql.NotNullFlag).build() tests := []struct { exprStr string result string }{ {"a < '1.0'", "lt(a, 1)"}, {"a <= '1.0'", "le(a, 1)"}, {"a > '1'", "gt(a, 1)"}, {"a >= '1'", "ge(a, 1)"}, {"a = '1'", "eq(a, 1)"}, {"a <=> '1'", "nulleq(a, 1)"}, {"a != '1'", "ne(a, 1)"}, {"a < '1.1'", "lt(a, 2)"}, {"a <= '1.1'", "le(a, 1)"}, {"a > 1.1", "gt(a, 1)"}, {"a >= '1.1'", "ge(a, 2)"}, {"a = '1.1'", "0"}, {"a <=> '1.1'", "0"}, {"a != '1.1'", "ne(cast(a, double BINARY), 1.1)"}, {"'1' < a", "lt(1, a)"}, {"'1' <= a", "le(1, a)"}, {"'1' > a", "gt(1, a)"}, {"'1' >= a", "ge(1, a)"}, {"'1' = a", "eq(1, a)"}, {"'1' <=> a", "nulleq(1, a)"}, {"'1' != a", "ne(1, a)"}, {"'1.1' < a", "lt(1, a)"}, {"'1.1' <= a", "le(2, a)"}, {"'1.1' > a", "gt(2, a)"}, {"'1.1' >= a", "ge(1, a)"}, {"'1.1' = a", "0"}, {"'1.1' <=> a", "0"}, {"'1.1' != a", "ne(1.1, cast(a, double BINARY))"}, {"'123456789123456711111189' = a", "0"}, {"123456789123456789.12345 = a", "0"}, {"123456789123456789123456789.12345 > a", "1"}, {"-123456789123456789123456789.12345 > a", "0"}, {"123456789123456789123456789.12345 < a", "0"}, {"-123456789123456789123456789.12345 < a", "1"}, {"'aaaa'=a", "eq(0, a)"}, } cols, names, err := ColumnInfos2ColumnsAndNames(ctx, model.NewCIStr(""), tblInfo.Name, tblInfo.Cols(), tblInfo) require.NoError(t, err) schema := NewSchema(cols...) for _, test := range tests { f, err := ParseSimpleExprsWithNames(ctx, test.exprStr, schema, names) require.NoError(t, err) require.Equal(t, test.result, f[0].String()) } } func TestCompare(t *testing.T) { ctx := createContext(t) intVal, uintVal, realVal, stringVal, decimalVal := 1, uint64(1), 1.1, "123", types.NewDecFromFloatForTest(123.123) timeVal := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeDatetime, 6) durationVal := types.Duration{Duration: 12*time.Hour + 1*time.Minute + 1*time.Second} jsonVal := json.CreateBinary("123") // test cases for generating function signatures. tests := []struct { arg0 interface{} arg1 interface{} funcName string tp byte expected int64 }{ {intVal, intVal, ast.LT, mysql.TypeLonglong, 0}, {stringVal, stringVal, ast.LT, mysql.TypeVarString, 0}, {intVal, decimalVal, ast.LT, mysql.TypeNewDecimal, 1}, {realVal, decimalVal, ast.LT, mysql.TypeDouble, 1}, {durationVal, durationVal, ast.LT, mysql.TypeDuration, 0}, {realVal, realVal, ast.LT, mysql.TypeDouble, 0}, {intVal, intVal, ast.NullEQ, mysql.TypeLonglong, 1}, {decimalVal, decimalVal, ast.LE, mysql.TypeNewDecimal, 1}, {decimalVal, decimalVal, ast.GT, mysql.TypeNewDecimal, 0}, {decimalVal, decimalVal, ast.GE, mysql.TypeNewDecimal, 1}, {decimalVal, decimalVal, ast.NE, mysql.TypeNewDecimal, 0}, {decimalVal, decimalVal, ast.EQ, mysql.TypeNewDecimal, 1}, {decimalVal, decimalVal, ast.NullEQ, mysql.TypeNewDecimal, 1}, {durationVal, durationVal, ast.LE, mysql.TypeDuration, 1}, {durationVal, durationVal, ast.GT, mysql.TypeDuration, 0}, {durationVal, durationVal, ast.GE, mysql.TypeDuration, 1}, {durationVal, durationVal, ast.EQ, mysql.TypeDuration, 1}, {durationVal, durationVal, ast.NE, mysql.TypeDuration, 0}, {durationVal, durationVal, ast.NullEQ, mysql.TypeDuration, 1}, {nil, nil, ast.NullEQ, mysql.TypeNull, 1}, {nil, intVal, ast.NullEQ, mysql.TypeDouble, 0}, {uintVal, intVal, ast.NullEQ, mysql.TypeLonglong, 1}, {uintVal, intVal, ast.EQ, mysql.TypeLonglong, 1}, {intVal, uintVal, ast.NullEQ, mysql.TypeLonglong, 1}, {intVal, uintVal, ast.EQ, mysql.TypeLonglong, 1}, {timeVal, timeVal, ast.LT, mysql.TypeDatetime, 0}, {timeVal, timeVal, ast.LE, mysql.TypeDatetime, 1}, {timeVal, timeVal, ast.GT, mysql.TypeDatetime, 0}, {timeVal, timeVal, ast.GE, mysql.TypeDatetime, 1}, {timeVal, timeVal, ast.EQ, mysql.TypeDatetime, 1}, {timeVal, timeVal, ast.NE, mysql.TypeDatetime, 0}, {timeVal, timeVal, ast.NullEQ, mysql.TypeDatetime, 1}, {jsonVal, jsonVal, ast.LT, mysql.TypeJSON, 0}, {jsonVal, jsonVal, ast.LE, mysql.TypeJSON, 1}, {jsonVal, jsonVal, ast.GT, mysql.TypeJSON, 0}, {jsonVal, jsonVal, ast.GE, mysql.TypeJSON, 1}, {jsonVal, jsonVal, ast.NE, mysql.TypeJSON, 0}, {jsonVal, jsonVal, ast.EQ, mysql.TypeJSON, 1}, {jsonVal, jsonVal, ast.NullEQ, mysql.TypeJSON, 1}, } for _, test := range tests { bf, err := funcs[test.funcName].getFunction(ctx, primitiveValsToConstants(ctx, []interface{}{test.arg0, test.arg1})) require.NoError(t, err) args := bf.getArgs() require.Equal(t, test.tp, args[0].GetType().Tp) require.Equal(t, test.tp, args[1].GetType().Tp) res, isNil, err := bf.evalInt(chunk.Row{}) require.NoError(t, err) require.False(t, isNil) require.Equal(t, test.expected, res) } // test decimalCol, stringCon := &Column{RetType: types.NewFieldType(mysql.TypeNewDecimal)}, &Constant{RetType: types.NewFieldType(mysql.TypeVarchar)} bf, err := funcs[ast.LT].getFunction(ctx, []Expression{decimalCol, stringCon}) require.NoError(t, err) args := bf.getArgs() require.Equal(t, mysql.TypeNewDecimal, args[0].GetType().Tp) require.Equal(t, mysql.TypeNewDecimal, args[1].GetType().Tp) // test