Files
tidb/expression/util.go

150 lines
3.8 KiB
Go

// Copyright 2016 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package expression
import (
"unicode"
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util/types"
)
// ExtractColumns walks expr and gathers every *Column it contains.
// A *Column yields itself; a *ScalarFunction yields the columns found
// recursively in its arguments, in argument order. Any other kind of
// expression contributes no columns, so the result is nil.
func ExtractColumns(expr Expression) (cols []*Column) {
	if col, isColumn := expr.(*Column); isColumn {
		return []*Column{col}
	}
	fn, isFunc := expr.(*ScalarFunction)
	if !isFunc {
		return cols
	}
	for _, arg := range fn.GetArgs() {
		cols = append(cols, ExtractColumns(arg)...)
	}
	return cols
}
// ColumnSubstitute replaces the columns of expr that appear in schema with
// a clone of the expression at the same position in newExprs.
// e.g. select * from (select b as a from t) k where a < 10
// => select * from (select b as a from t where b < 10) k.
//
// Columns that schema does not contain are returned untouched. A cast
// function is cloned and only its first argument is substituted; every
// other scalar function is rebuilt through NewFunction from its substituted
// arguments. Expressions that are neither columns nor scalar functions are
// returned as they are.
func ColumnSubstitute(expr Expression, schema Schema, newExprs []Expression) Expression {
	if col, isColumn := expr.(*Column); isColumn {
		if pos := schema.GetColumnIndex(col); pos != -1 {
			return newExprs[pos].Clone()
		}
		return col
	}

	sf, isFunc := expr.(*ScalarFunction)
	if !isFunc {
		return expr
	}

	if sf.FuncName.L == ast.Cast {
		// Work on a clone so the original cast expression is not modified.
		castFunc := sf.Clone().(*ScalarFunction)
		castArgs := castFunc.GetArgs()
		castArgs[0] = ColumnSubstitute(castArgs[0], schema, newExprs)
		return castFunc
	}

	origArgs := sf.GetArgs()
	substituted := make([]Expression, len(origArgs))
	for i := range origArgs {
		substituted[i] = ColumnSubstitute(origArgs[i], schema, newExprs)
	}
	// The error from NewFunction is discarded: this function has no error
	// result to report it through.
	rebuilt, _ := NewFunction(sf.FuncName.L, sf.RetType, substituted...)
	return rebuilt
}
// datumsToConstants wraps each datum in a *Constant expression and returns
// them in the same order as the input.
func datumsToConstants(datums []types.Datum) []Expression {
	result := make([]Expression, len(datums))
	for i := range datums {
		result[i] = &Constant{Value: datums[i]}
	}
	return result
}
// calculateSum adds v to sum and returns the new running total.
//
// It is used for avg and sum calculation: integer and decimal values are
// accumulated as decimals, every other non-null kind is accumulated as
// float64.
// See https://dev.mysql.com/doc/refman/5.7/en/group-by-functions.html
//
// A conversion failure is returned as a traced error. An error is also
// returned when sum is neither null, float64 nor decimal.
func calculateSum(sc *variable.StatementContext, sum, v types.Datum) (data types.Datum, err error) {
	// Step 1: normalise v into the representation used for accumulation.
	switch v.Kind() {
	case types.KindNull:
		// Nothing to convert; data stays unset.
	case types.KindInt64, types.KindUint64:
		dec, convErr := v.ToDecimal(sc)
		if convErr != nil {
			return data, errors.Trace(convErr)
		}
		data = types.NewDecimalDatum(dec)
	case types.KindMysqlDecimal:
		data = v
	default:
		num, convErr := v.ToFloat64(sc)
		if convErr != nil {
			return data, errors.Trace(convErr)
		}
		data = types.NewFloat64Datum(num)
	}

	// A null data (expected when v is null, since nothing was assigned
	// above) contributes nothing to the total.
	if data.IsNull() {
		return sum, nil
	}

	// Step 2: fold the normalised value into the running sum.
	sumKind := sum.Kind()
	if sumKind == types.KindNull {
		// First non-null value seen: it becomes the sum.
		return data, nil
	}
	if sumKind == types.KindFloat64 || sumKind == types.KindMysqlDecimal {
		return types.ComputePlus(sum, data)
	}
	return data, errors.Errorf("invalid value %v for aggregate", sumKind)
}
// getValidPrefix returns the longest prefix of s that can be parsed as a
// number in the given base. The minimum base is 2 and the maximum is 36;
// for any base outside that range the empty string is returned.
//
// Digits are '0'-'9' followed by letters, matched case-insensitively, and a
// character counts only if its value is below base. A single '+' or '-' is
// accepted as the very first character. A leading '+' is dropped from the
// result, a leading '-' is kept. The returned text is a slice of s, so
// letters keep their original case. A sign with no valid digit after it
// yields the empty string.
func getValidPrefix(s string, base int64) string {
	// upper is the first character (in upper case) that is NOT a valid
	// digit for this base.
	var upper rune
	switch {
	case base >= 2 && base <= 9:
		upper = rune('0' + base)
	case base >= 10 && base <= 36:
		// The lower bound matters: without it, bases below 2 would fall
		// into this branch and produce a bogus digit limit instead of
		// being rejected.
		upper = rune('A' + base - 10)
	default:
		return ""
	}

	validLen := 0
Loop:
	for i := 0; i < len(s); i++ {
		c := rune(s[i])
		switch {
		case unicode.IsDigit(c) || unicode.IsLower(c) || unicode.IsUpper(c):
			if unicode.ToUpper(c) >= upper {
				break Loop
			}
			validLen = i + 1
		case c == '+' || c == '-':
			// A sign is only meaningful in the first position; it does
			// not extend validLen until a digit follows it.
			if i != 0 {
				break Loop
			}
		default:
			break Loop
		}
	}

	if validLen > 1 && s[0] == '+' {
		return s[1:validLen]
	}
	return s[:validLen]
}