diff --git a/ast/functions.go b/ast/functions.go index 70db57e721..8e1f4ebd59 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -331,17 +331,22 @@ func (n *FuncTrimExpr) Accept(v Visitor) (Node, bool) { return n, false } n.Str = node.(ExprNode) - node, ok = n.RemStr.Accept(v) - if !ok { - return n, false + if n.RemStr != nil { + node, ok = n.RemStr.Accept(v) + if !ok { + return n, false + } + n.RemStr = node.(ExprNode) } - n.RemStr = node.(ExprNode) return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. func (n *FuncTrimExpr) IsStatic() bool { - return n.Str.IsStatic() && n.RemStr.IsStatic() + if n.RemStr != nil { + return n.Str.IsStatic() && n.RemStr.IsStatic() + } + return n.Str.IsStatic() } // DateArithType is type for DateArith type. diff --git a/optimizer/evaluator/evaluator.go b/optimizer/evaluator/evaluator.go index abe119fda3..f30e49b12d 100644 --- a/optimizer/evaluator/evaluator.go +++ b/optimizer/evaluator/evaluator.go @@ -55,6 +55,13 @@ func EvalBool(ctx context.Context, expr ast.ExprNode) (bool, error) { return i != 0, nil } +func boolToInt64(v bool) int64 { + if v { + return int64(1) + } + return int64(0) +} + // Evaluator is a ast Visitor that evaluates an expression. type Evaluator struct { ctx context.Context @@ -254,7 +261,7 @@ func (e *Evaluator) patternIn(n *ast.PatternInExpr) bool { return false } if r == 0 { - n.SetValue(!n.Not) + n.SetValue(boolToInt64(!n.Not)) return true } } @@ -264,7 +271,7 @@ func (e *Evaluator) patternIn(n *ast.PatternInExpr) bool { n.SetValue(nil) return true } - n.SetValue(n.Not) + n.SetValue(boolToInt64(n.Not)) return true } @@ -276,7 +283,7 @@ func (e *Evaluator) isNull(v *ast.IsNullExpr) bool { if v.Not { boolVal = !boolVal } - v.SetValue(boolVal) + v.SetValue(boolToInt64(boolVal)) return true } @@ -296,7 +303,7 @@ func (e *Evaluator) isTruth(v *ast.IsTruthExpr) bool { if v.Not { boolVal = !boolVal } - v.SetValue(boolVal) + v.SetValue(boolToInt64(boolVal)) return true } @@ -356,11 +363,7 @@ func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool { switch x := a.(type) { case nil: case bool: - if x { - u.SetValue(int64(1)) - } else { - u.SetValue(int64(0)) - } + u.SetValue(boolToInt64(x)) case float32: u.SetValue(+x) case float64: diff --git a/optimizer/evaluator/evaluator_like.go b/optimizer/evaluator/evaluator_like.go index e5507595af..212c48b688 100644 --- a/optimizer/evaluator/evaluator_like.go +++ b/optimizer/evaluator/evaluator_like.go @@ -158,7 +158,7 @@ func (e *Evaluator) patternLike(p *ast.PatternLikeExpr) bool { if p.Not { match = !match } - p.SetValue(match) + p.SetValue(boolToInt64(match)) return true } @@ -212,6 +212,6 @@ func (e *Evaluator) patternRegexp(p *ast.PatternRegexpExpr) bool { if p.Not { match = !match } - p.SetValue(match) + p.SetValue(boolToInt64(match)) return true } diff --git a/optimizer/evaluator/evaluator_test.go b/optimizer/evaluator/evaluator_test.go index 06d3c704e8..568a6008a1 100644 --- a/optimizer/evaluator/evaluator_test.go +++ b/optimizer/evaluator/evaluator_test.go @@ -16,6 +16,7 @@ package evaluator import ( "fmt" "testing" + "time" . "github.com/pingcap/check" "github.com/pingcap/tidb/ast" @@ -24,7 +25,6 @@ import ( "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/types" - "time" ) var _ = Suite(&testEvaluatorSuite{}) @@ -58,7 +58,7 @@ func (s *testEvaluatorSuite) runTests(c *C, cases []testCase) { val, err := Eval(ctx, expr) c.Assert(err, IsNil) valStr := fmt.Sprintf("%v", val) - c.Assert(valStr, Equals, ca.resultStr) + c.Assert(valStr, Equals, ca.resultStr, Commentf("for %s", ca.exprStr)) } } @@ -559,3 +559,446 @@ func (s *testEvaluatorSuite) TestDateArith(c *C) { c.Assert(err, NotNil, Commentf("%s", v)) } } + +func (s *testEvaluatorSuite) TestExtract(c *C) { + str := "2011-11-11 10:10:10.123456" + tbl := []struct { + Unit string + Expect int64 + }{ + {"MICROSECOND", 123456}, + {"SECOND", 10}, + {"MINUTE", 10}, + {"HOUR", 10}, + {"DAY", 11}, + {"WEEK", 45}, + {"MONTH", 11}, + {"QUARTER", 4}, + {"YEAR", 2011}, + {"SECOND_MICROSECOND", 10123456}, + {"MINUTE_MICROSECOND", 1010123456}, + {"MINUTE_SECOND", 1010}, + {"HOUR_MICROSECOND", 101010123456}, + {"HOUR_SECOND", 101010}, + {"HOUR_MINUTE", 1010}, + {"DAY_MICROSECOND", 11101010123456}, + {"DAY_SECOND", 11101010}, + {"DAY_MINUTE", 111010}, + {"DAY_HOUR", 1110}, + {"YEAR_MONTH", 201111}, + } + ctx := mock.NewContext() + for _, t := range tbl { + e := &ast.FuncExtractExpr{ + Unit: t.Unit, + Date: ast.NewValueExpr(str), + } + + v, err := Eval(ctx, e) + c.Assert(err, IsNil) + c.Assert(v, Equals, t.Expect) + } + + // Test nil + e := &ast.FuncExtractExpr{ + Unit: "SECOND", + Date: ast.NewValueExpr(nil), + } + + v, err := Eval(ctx, e) + c.Assert(err, IsNil) + c.Assert(v, IsNil) +} + +func (s *testEvaluatorSuite) TestPatternIn(c *C) { + cases := []testCase{ + { + exprStr: "1 in (1, 2, 3)", + resultStr: "1", + }, + { + exprStr: "1 in (2, 3)", + resultStr: "0", + }, + { + exprStr: "NULL in (2, 3)", + resultStr: "", + }, + { + exprStr: "NULL not in (2, 3)", + resultStr: "", + }, { + exprStr: "NULL in (NULL, 3)", + resultStr: "", + }, + } + s.runTests(c, cases) +} + +func (s *testEvaluatorSuite) TestIsNull(c *C) { + cases := []testCase{ + { + exprStr: "1 IS NULL", + resultStr: "0", + }, + { + exprStr: "1 IS NOT NULL", + resultStr: "1", + }, + { + exprStr: "NULL IS NULL", + resultStr: "1", + }, + { + exprStr: "NULL IS NOT NULL", + resultStr: "0", + }, + } + s.runTests(c, cases) +} + +func (s *testEvaluatorSuite) TestIsTruth(c *C) { + cases := []testCase{ + { + exprStr: "1 IS TRUE", + resultStr: "1", + }, + { + exprStr: "2 IS TRUE", + resultStr: "1", + }, + { + exprStr: "0 IS TRUE", + resultStr: "0", + }, + { + exprStr: "NULL IS TRUE", + resultStr: "0", + }, + { + exprStr: "1 IS FALSE", + resultStr: "0", + }, + { + exprStr: "2 IS FALSE", + resultStr: "0", + }, + { + exprStr: "0 IS FALSE", + resultStr: "1", + }, + { + exprStr: "NULL IS NOT FALSE", + resultStr: "1", + }, + { + exprStr: "1 IS NOT TRUE", + resultStr: "0", + }, + { + exprStr: "2 IS NOT TRUE", + resultStr: "0", + }, + { + exprStr: "0 IS NOT TRUE", + resultStr: "1", + }, + { + exprStr: "NULL IS NOT TRUE", + resultStr: "1", + }, + { + exprStr: "1 IS NOT FALSE", + resultStr: "1", + }, + { + exprStr: "2 IS NOT FALSE", + resultStr: "1", + }, + { + exprStr: "0 IS NOT FALSE", + resultStr: "0", + }, + { + exprStr: "NULL IS NOT FALSE", + resultStr: "1", + }, + } + s.runTests(c, cases) +} + +func (s *testEvaluatorSuite) TestLike(c *C) { + tbl := []struct { + pattern string + input string + escape byte + match bool + }{ + {"", "a", '\\', false}, + {"a", "a", '\\', true}, + {"a", "b", '\\', false}, + {"aA", "aA", '\\', true}, + {"_", "a", '\\', true}, + {"_", "ab", '\\', false}, + {"__", "b", '\\', false}, + {"_ab", "AAB", '\\', true}, + {"%", "abcd", '\\', true}, + {"%", "", '\\', true}, + {"%a", "AAA", '\\', true}, + {"%b", "AAA", '\\', false}, + {"b%", "BBB", '\\', true}, + {"%a%", "BBB", '\\', false}, + {"%a%", "BAB", '\\', true}, + {"a%", "BBB", '\\', false}, + {`\%a`, `%a`, '\\', true}, + {`\%a`, `aa`, '\\', false}, + {`\_a`, `_a`, '\\', true}, + {`\_a`, `aa`, '\\', false}, + {`\\_a`, `\xa`, '\\', true}, + {`\a\b`, `\a\b`, '\\', true}, + {"%%_", `abc`, '\\', true}, + {`+_a`, `_a`, '+', true}, + {`+%a`, `%a`, '+', true}, + {`\%a`, `%a`, '+', false}, + {`++a`, `+a`, '+', true}, + {`++_a`, `+xa`, '+', true}, + } + for _, v := range tbl { + patChars, patTypes := compilePattern(v.pattern, v.escape) + match := doMatch(v.input, patChars, patTypes) + c.Assert(match, Equals, v.match, Commentf("%v", v)) + } + cases := []testCase{ + { + exprStr: "'a' LIKE ''", + resultStr: "0", + }, + { + exprStr: "'a' LIKE 'a'", + resultStr: "1", + }, + { + exprStr: "'a' LIKE 'b'", + resultStr: "0", + }, + { + exprStr: "'aA' LIKE 'Aa'", + resultStr: "1", + }, + { + exprStr: "'aAb' LIKE 'Aa%'", + resultStr: "1", + }, + { + exprStr: "'aAb' LIKE 'Aa_'", + resultStr: "1", + }, + } + s.runTests(c, cases) +} + +func (s *testEvaluatorSuite) TestRegexp(c *C) { + tbl := []struct { + pattern string + input string + match int64 + }{ + {"^$", "a", 0}, + {"a", "a", 1}, + {"a", "b", 0}, + {"aA", "aA", 1}, + {".", "a", 1}, + {"^.$", "ab", 0}, + {"..", "b", 0}, + {".ab", "aab", 1}, + {".*", "abcd", 1}, + } + ctx := mock.NewContext() + for _, v := range tbl { + pattern := &ast.PatternRegexpExpr{ + Pattern: ast.NewValueExpr(v.pattern), + Expr: ast.NewValueExpr(v.input), + } + match, err := Eval(ctx, pattern) + c.Assert(err, IsNil) + c.Assert(match, Equals, v.match, Commentf("%v", v)) + } +} + +func (s *testEvaluatorSuite) TestSubstring(c *C) { + tbl := []struct { + str string + pos int64 + slen int64 + result string + }{ + {"Quadratically", 5, -1, "ratically"}, + {"foobarbar", 4, -1, "barbar"}, + {"Quadratically", 5, 6, "ratica"}, + {"Sakila", -3, -1, "ila"}, + {"Sakila", -5, 3, "aki"}, + {"Sakila", -4, 2, "ki"}, + {"Sakila", 1000, 2, ""}, + {"", 2, 3, ""}, + } + ctx := mock.NewContext() + for _, v := range tbl { + f := &ast.FuncSubstringExpr{ + StrExpr: ast.NewValueExpr(v.str), + Pos: ast.NewValueExpr(v.pos), + } + if v.slen != -1 { + f.Len = ast.NewValueExpr(v.slen) + } + c.Assert(f.IsStatic(), Equals, true) + + r, err := Eval(ctx, f) + c.Assert(err, IsNil) + s, ok := r.(string) + c.Assert(ok, Equals, true) + c.Assert(s, Equals, v.result) + + r1, err := Eval(ctx, f) + c.Assert(err, IsNil) + s1, ok := r1.(string) + c.Assert(ok, Equals, true) + c.Assert(s, Equals, s1) + } + errTbl := []struct { + str interface{} + pos interface{} + len interface{} + result string + }{ + {1, 5, -1, "ratically"}, + {"foobarbar", "4", -1, "barbar"}, + {"Quadratically", 5, "6", "ratica"}, + } + for _, v := range errTbl { + f := &ast.FuncSubstringExpr{ + StrExpr: ast.NewValueExpr(v.str), + Pos: ast.NewValueExpr(v.pos), + } + if v.len != -1 { + f.Len = ast.NewValueExpr(v.len) + } + _, err := Eval(ctx, f) + c.Assert(err, NotNil) + } +} + +func (s *testEvaluatorSuite) TestTrim(c *C) { + tbl := []struct { + str interface{} + remstr interface{} + dir ast.TrimDirectionType + result interface{} + }{ + {" bar ", nil, ast.TrimBothDefault, "bar"}, + {"xxxbarxxx", "x", ast.TrimLeading, "barxxx"}, + {"xxxbarxxx", "x", ast.TrimBoth, "bar"}, + {"barxxyz", "xyz", ast.TrimTrailing, "barx"}, + {nil, "xyz", ast.TrimBoth, nil}, + {1, 2, ast.TrimBoth, "1"}, + {" \t\rbar\n ", nil, ast.TrimBothDefault, "bar"}, + } + ctx := mock.NewContext() + for _, v := range tbl { + f := &ast.FuncTrimExpr{ + Str: ast.NewValueExpr(v.str), + Direction: v.dir, + } + if v.remstr != nil { + f.RemStr = ast.NewValueExpr(v.remstr) + } + c.Assert(f.IsStatic(), Equals, true) + + r, err := Eval(ctx, f) + c.Assert(err, IsNil) + c.Assert(r, Equals, v.result) + } +} + +func (s *testEvaluatorSuite) TestUnaryOp(c *C) { + tbl := []struct { + arg interface{} + op opcode.Op + result interface{} + }{ + // test NOT. + {1, opcode.Not, int64(0)}, + {0, opcode.Not, int64(1)}, + {nil, opcode.Not, nil}, + {mysql.Hex{Value: 0}, opcode.Not, int64(1)}, + {mysql.Bit{Value: 0, Width: 1}, opcode.Not, int64(1)}, + {mysql.Enum{Name: "a", Value: 1}, opcode.Not, int64(0)}, + {mysql.Set{Name: "a", Value: 1}, opcode.Not, int64(0)}, + + // test BitNeg. + {nil, opcode.BitNeg, nil}, + {-1, opcode.BitNeg, uint64(0)}, + + // test Plus. + {nil, opcode.Plus, nil}, + {float64(1.0), opcode.Plus, float64(1.0)}, + {int(1), opcode.Plus, int(1)}, + {int64(1), opcode.Plus, int64(1)}, + {uint64(1), opcode.Plus, uint64(1)}, + {"1.0", opcode.Plus, "1.0"}, + {[]byte("1.0"), opcode.Plus, []byte("1.0")}, + {mysql.Hex{Value: 1}, opcode.Plus, mysql.Hex{Value: 1}}, + {mysql.Bit{Value: 1, Width: 1}, opcode.Plus, mysql.Bit{Value: 1, Width: 1}}, + {true, opcode.Plus, int64(1)}, + {false, opcode.Plus, int64(0)}, + {mysql.Enum{Name: "a", Value: 1}, opcode.Plus, mysql.Enum{Name: "a", Value: 1}}, + {mysql.Set{Name: "a", Value: 1}, opcode.Plus, mysql.Set{Name: "a", Value: 1}}, + + // test Minus. + {nil, opcode.Minus, nil}, + {float64(1.0), opcode.Minus, float64(-1.0)}, + {int(1), opcode.Minus, int(-1)}, + {int64(1), opcode.Minus, int64(-1)}, + {uint64(1), opcode.Minus, -int64(1)}, + {"1.0", opcode.Minus, -1.0}, + {[]byte("1.0"), opcode.Minus, -1.0}, + {mysql.Hex{Value: 1}, opcode.Minus, -1.0}, + {mysql.Bit{Value: 1, Width: 1}, opcode.Minus, -1.0}, + {true, opcode.Minus, int64(-1)}, + {false, opcode.Minus, int64(0)}, + {mysql.Enum{Name: "a", Value: 1}, opcode.Minus, -1.0}, + {mysql.Set{Name: "a", Value: 1}, opcode.Minus, -1.0}, + } + ctx := mock.NewContext() + for _, t := range tbl { + expr := &ast.UnaryOperationExpr{Op: t.op, V: ast.NewValueExpr(t.arg)} + result, err := Eval(ctx, expr) + c.Assert(err, IsNil) + c.Assert(result, DeepEquals, t.result) + } + + tbl = []struct { + arg interface{} + op opcode.Op + result interface{} + }{ + {mysql.NewDecimalFromInt(1, 0), opcode.Plus, mysql.NewDecimalFromInt(1, 0)}, + {mysql.Duration{Duration: time.Duration(838*3600 + 59*60 + 59), Fsp: mysql.DefaultFsp}, opcode.Plus, + mysql.Duration{Duration: time.Duration(838*3600 + 59*60 + 59), Fsp: mysql.DefaultFsp}}, + {mysql.Time{Time: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), Type: mysql.TypeDatetime, Fsp: 0}, opcode.Plus, mysql.Time{Time: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), Type: mysql.TypeDatetime, Fsp: 0}}, + + {mysql.NewDecimalFromInt(1, 0), opcode.Minus, mysql.NewDecimalFromInt(-1, 0)}, + {mysql.ZeroDuration, opcode.Minus, mysql.NewDecimalFromInt(0, 0)}, + {mysql.Time{Time: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), Type: mysql.TypeDatetime, Fsp: 0}, opcode.Minus, mysql.NewDecimalFromInt(-20091110230000, 0)}, + } + + for _, t := range tbl { + expr := &ast.UnaryOperationExpr{Op: t.op, V: ast.NewValueExpr(t.arg)} + + result, err := Eval(ctx, expr) + c.Assert(err, IsNil) + + ret, err := types.Compare(result, t.result) + c.Assert(err, IsNil) + c.Assert(ret, Equals, 0) + } +} diff --git a/optimizer/plan/plan_test.go b/optimizer/plan/plan_test.go index e919c15186..0bf285cc2c 100644 --- a/optimizer/plan/plan_test.go +++ b/optimizer/plan/plan_test.go @@ -133,7 +133,7 @@ func (s *testPlanSuite) TestRangeBuilder(c *C) { resultStr: "[[abc abd)]", }, { - exprStr: "a LIKE 'abc.'", + exprStr: "a LIKE 'abc_'", resultStr: "[(abc abd)]", }, { diff --git a/optimizer/plan/range.go b/optimizer/plan/range.go index c339a0952c..78d2938970 100644 --- a/optimizer/plan/range.go +++ b/optimizer/plan/range.go @@ -128,7 +128,6 @@ func (r *rangeBuilder) build(expr ast.ExprNode) []rangePoint { case *ast.PatternLikeExpr: return r.buildFromPatternLike(x) case *ast.ColumnNameExpr: - fmt.Println("build column name") return r.buildFromColumnName(x) } return fullRange @@ -273,18 +272,18 @@ func (r *rangeBuilder) buildFromPatternLike(x *ast.PatternLikeExpr) []rangePoint // unscape the pattern var exclude bool for i := 0; i < len(pattern); i++ { - if pattern[i] == '\\' { + if pattern[i] == x.Escape { i++ if i < len(pattern) { lowValue = append(lowValue, pattern[i]) } else { - lowValue = append(lowValue, '\\') + lowValue = append(lowValue, x.Escape) } continue } if pattern[i] == '%' { break - } else if pattern[i] == '.' { + } else if pattern[i] == '_' { exclude = true break }