Files
tidb/executor/aggfuncs/func_count.go

389 lines
9.0 KiB
Go

package aggfuncs
import (
"encoding/binary"
"unsafe"
"github.com/juju/errors"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
)
// baseCount is the common base embedded by the count aggregate functions in
// this file. It embeds baseAggFunc and supplies the allocation, reset and
// final-result methods for a partial result stored as partialResult4Count.
type baseCount struct {
	baseAggFunc
}
// partialResult4Count is the partial-result representation used by the
// non-distinct count functions: a plain int64 counter.
type partialResult4Count = int64
// AllocPartialResult allocates a fresh counter initialised to zero and
// returns it converted to a PartialResult.
func (e *baseCount) AllocPartialResult() PartialResult {
	var counter partialResult4Count
	return PartialResult(&counter)
}
// ResetPartialResult sets the counter referenced by pr back to zero so the
// same partial result can be reused.
func (e *baseCount) ResetPartialResult(pr PartialResult) {
	*(*partialResult4Count)(pr) = 0
}
// AppendFinalResult2Chunk appends the counter stored in pr to chk as an int64,
// using e.ordinal as the target position. It always returns nil.
func (e *baseCount) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
	total := (*partialResult4Count)(pr)
	chk.AppendInt64(e.ordinal, *total)
	return nil
}
// countOriginal4Int counts the rows whose first argument, evaluated with
// EvalInt, is not NULL.
type countOriginal4Int struct {
	baseCount
}
// UpdatePartialResult increments the counter in pr once for every row of
// rowsInGroup whose first argument evaluates (via EvalInt) to a non-NULL
// value. On an evaluation error it stops and returns the traced error; rows
// counted before the failure stay counted.
func (e *countOriginal4Int) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
	counter := (*partialResult4Count)(pr)
	for i := range rowsInGroup {
		_, null, evalErr := e.args[0].EvalInt(sctx, rowsInGroup[i])
		switch {
		case evalErr != nil:
			return errors.Trace(evalErr)
		case !null:
			*counter++
		}
	}
	return nil
}
// countOriginal4Real counts the rows whose first argument, evaluated with
// EvalReal, is not NULL.
type countOriginal4Real struct {
	baseCount
}
// UpdatePartialResult increments the counter in pr once for every row of
// rowsInGroup whose first argument evaluates (via EvalReal) to a non-NULL
// value. On an evaluation error it stops and returns the traced error; rows
// counted before the failure stay counted.
func (e *countOriginal4Real) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
	counter := (*partialResult4Count)(pr)
	for i := range rowsInGroup {
		_, null, evalErr := e.args[0].EvalReal(sctx, rowsInGroup[i])
		switch {
		case evalErr != nil:
			return errors.Trace(evalErr)
		case !null:
			*counter++
		}
	}
	return nil
}
// countOriginal4Decimal counts the rows whose first argument, evaluated with
// EvalDecimal, is not NULL.
type countOriginal4Decimal struct {
	baseCount
}
// UpdatePartialResult increments the counter in pr once for every row of
// rowsInGroup whose first argument evaluates (via EvalDecimal) to a non-NULL
// value. On an evaluation error it stops and returns the traced error; rows
// counted before the failure stay counted.
func (e *countOriginal4Decimal) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
	counter := (*partialResult4Count)(pr)
	for i := range rowsInGroup {
		_, null, evalErr := e.args[0].EvalDecimal(sctx, rowsInGroup[i])
		switch {
		case evalErr != nil:
			return errors.Trace(evalErr)
		case !null:
			*counter++
		}
	}
	return nil
}
// countOriginal4Time counts the rows whose first argument, evaluated with
// EvalTime, is not NULL.
type countOriginal4Time struct {
	baseCount
}
// UpdatePartialResult increments the counter in pr once for every row of
// rowsInGroup whose first argument evaluates (via EvalTime) to a non-NULL
// value. On an evaluation error it stops and returns the traced error; rows
// counted before the failure stay counted.
func (e *countOriginal4Time) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
	counter := (*partialResult4Count)(pr)
	for i := range rowsInGroup {
		_, null, evalErr := e.args[0].EvalTime(sctx, rowsInGroup[i])
		switch {
		case evalErr != nil:
			return errors.Trace(evalErr)
		case !null:
			*counter++
		}
	}
	return nil
}
// countOriginal4Duration counts the rows whose first argument, evaluated with
// EvalDuration, is not NULL.
type countOriginal4Duration struct {
	baseCount
}
// UpdatePartialResult increments the counter in pr once for every row of
// rowsInGroup whose first argument evaluates (via EvalDuration) to a non-NULL
// value. On an evaluation error it stops and returns the traced error; rows
// counted before the failure stay counted.
func (e *countOriginal4Duration) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
	counter := (*partialResult4Count)(pr)
	for i := range rowsInGroup {
		_, null, evalErr := e.args[0].EvalDuration(sctx, rowsInGroup[i])
		switch {
		case evalErr != nil:
			return errors.Trace(evalErr)
		case !null:
			*counter++
		}
	}
	return nil
}
// countOriginal4JSON counts the rows whose first argument, evaluated with
// EvalJSON, is not NULL.
type countOriginal4JSON struct {
	baseCount
}
// UpdatePartialResult increments the counter in pr once for every row of
// rowsInGroup whose first argument evaluates (via EvalJSON) to a non-NULL
// value. On an evaluation error it stops and returns the traced error; rows
// counted before the failure stay counted.
func (e *countOriginal4JSON) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
	counter := (*partialResult4Count)(pr)
	for i := range rowsInGroup {
		_, null, evalErr := e.args[0].EvalJSON(sctx, rowsInGroup[i])
		switch {
		case evalErr != nil:
			return errors.Trace(evalErr)
		case !null:
			*counter++
		}
	}
	return nil
}
// countOriginal4String counts the rows whose first argument, evaluated with
// EvalString, is not NULL.
type countOriginal4String struct {
	baseCount
}
// UpdatePartialResult increments the counter in pr once for every row of
// rowsInGroup whose first argument evaluates (via EvalString) to a non-NULL
// value. On an evaluation error it stops and returns the traced error; rows
// counted before the failure stay counted.
func (e *countOriginal4String) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
	counter := (*partialResult4Count)(pr)
	for i := range rowsInGroup {
		_, null, evalErr := e.args[0].EvalString(sctx, rowsInGroup[i])
		switch {
		case evalErr != nil:
			return errors.Trace(evalErr)
		case !null:
			*counter++
		}
	}
	return nil
}
// countPartial sums the integer values of its first argument rather than
// counting rows, and supports merging two partial results.
// NOTE(review): presumably the argument carries counts produced by an earlier
// aggregation stage — confirm against the caller.
type countPartial struct {
	baseCount
}
// UpdatePartialResult adds to the counter in pr the integer value of the
// first argument for every row of rowsInGroup; rows where that value is NULL
// contribute nothing. On an evaluation error it stops and returns the traced
// error; values added before the failure stay added.
func (e *countPartial) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
	counter := (*partialResult4Count)(pr)
	for i := range rowsInGroup {
		delta, null, evalErr := e.args[0].EvalInt(sctx, rowsInGroup[i])
		switch {
		case evalErr != nil:
			return errors.Trace(evalErr)
		case !null:
			*counter += delta
		}
	}
	return nil
}
// MergePartialResult adds the counter held in src to the counter held in dst.
// src is left unchanged and the returned error is always nil.
func (*countPartial) MergePartialResult(sctx sessionctx.Context, src, dst PartialResult) error {
	from := (*partialResult4Count)(src)
	into := (*partialResult4Count)(dst)
	*into += *from
	return nil
}
// countOriginalWithDistinct counts the distinct argument tuples in which no
// argument is NULL. It overrides the partial-result methods of baseCount to
// use partialResult4CountWithDistinct.
type countOriginalWithDistinct struct {
	baseCount
}
// partialResult4CountWithDistinct is the partial result used by
// countOriginalWithDistinct.
type partialResult4CountWithDistinct struct {
	// count is the number of distinct encoded tuples inserted into valSet.
	count int64
	// valSet holds the encoded form of every tuple already counted.
	valSet stringSet
}
// AllocPartialResult allocates a partial result with a zero count and a
// freshly created, empty value set.
func (e *countOriginalWithDistinct) AllocPartialResult() PartialResult {
	res := new(partialResult4CountWithDistinct)
	res.valSet = newStringSet()
	return PartialResult(res)
}
// ResetPartialResult clears pr for reuse: the count goes back to zero and the
// value set is replaced by a new one from newStringSet.
func (e *countOriginalWithDistinct) ResetPartialResult(pr PartialResult) {
	res := (*partialResult4CountWithDistinct)(pr)
	res.count, res.valSet = 0, newStringSet()
}
// AppendFinalResult2Chunk appends the distinct count stored in pr to chk as
// an int64, using e.ordinal as the target position. It always returns nil.
func (e *countOriginalWithDistinct) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
	res := (*partialResult4CountWithDistinct)(pr)
	distinct := res.count
	chk.AppendInt64(e.ordinal, distinct)
	return nil
}
// UpdatePartialResult folds rowsInGroup into the distinct count held in pr.
//
// For each row, every argument is evaluated and encoded in order into one
// byte key. A row is skipped as soon as any argument is NULL. Otherwise the
// key is looked up in the value set; unseen keys are inserted and bump the
// count. The first evaluation/encoding error aborts the update and is
// returned as-is.
func (e *countOriginalWithDistinct) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) (err error) {
	res := (*partialResult4CountWithDistinct)(pr)
	key := make([]byte, 0)
	// Decimal struct is the biggest type we will use.
	scratch := make([]byte, types.MyDecimalStructSize)

nextRow:
	for _, row := range rowsInGroup {
		// Reuse the key buffer's backing storage for every row.
		key = key[:0]
		for _, arg := range e.args {
			var null bool
			key, null, err = e.evalAndEncode(sctx, arg, row, scratch, key)
			if err != nil {
				return err
			}
			if null {
				// A NULL in any position disqualifies the whole tuple.
				continue nextRow
			}
		}
		encoded := string(key)
		if !res.valSet.exist(encoded) {
			res.valSet.insert(encoded)
			res.count++
		}
	}
	return nil
}
// evalAndEncode evaluates arg against row and appends the encoded value to
// encodedBytes, choosing the Eval*/append* pair from the argument's eval type.
//
// buf is scratch space handed to the append helpers. The results are the
// (possibly extended) key bytes, whether the value was NULL, and any error.
// When the value is NULL or evaluation fails, encodedBytes is returned
// unmodified together with the NULL flag and the traced error. An eval type
// with no encoder yields a nil slice and an error.
func (e *countOriginalWithDistinct) evalAndEncode(
	sctx sessionctx.Context, arg expression.Expression,
	row chunk.Row, buf, encodedBytes []byte,
) ([]byte, bool, error) {
	switch evalTp := arg.GetType().EvalType(); evalTp {
	case types.ETInt:
		v, null, err := arg.EvalInt(sctx, row)
		if err != nil || null {
			return encodedBytes, null, errors.Trace(err)
		}
		return appendInt64(encodedBytes, buf, v), false, nil
	case types.ETReal:
		v, null, err := arg.EvalReal(sctx, row)
		if err != nil || null {
			return encodedBytes, null, errors.Trace(err)
		}
		return appendFloat64(encodedBytes, buf, v), false, nil
	case types.ETDecimal:
		v, null, err := arg.EvalDecimal(sctx, row)
		if err != nil || null {
			return encodedBytes, null, errors.Trace(err)
		}
		return appendDecimal(encodedBytes, buf, v), false, nil
	case types.ETTimestamp, types.ETDatetime:
		v, null, err := arg.EvalTime(sctx, row)
		if err != nil || null {
			return encodedBytes, null, errors.Trace(err)
		}
		return appendTime(encodedBytes, buf, v), false, nil
	case types.ETDuration:
		v, null, err := arg.EvalDuration(sctx, row)
		if err != nil || null {
			return encodedBytes, null, errors.Trace(err)
		}
		return appendDuration(encodedBytes, buf, v), false, nil
	case types.ETJson:
		v, null, err := arg.EvalJSON(sctx, row)
		if err != nil || null {
			return encodedBytes, null, errors.Trace(err)
		}
		return appendJSON(encodedBytes, buf, v), false, nil
	case types.ETString:
		v, null, err := arg.EvalString(sctx, row)
		if err != nil || null {
			return encodedBytes, null, errors.Trace(err)
		}
		return appendString(encodedBytes, buf, v), false, nil
	default:
		return nil, false, errors.Errorf("unsupported column type for encode %d", evalTp)
	}
}
// appendInt64 appends the 8-byte in-memory (native byte order) representation
// of val to encodedBytes and returns the extended slice.
//
// buf is scratch space: val is stored into its first 8 bytes through an
// unsafe pointer cast, so buf must be at least 8 bytes long.
func appendInt64(encodedBytes, buf []byte, val int64) []byte {
	*(*int64)(unsafe.Pointer(&buf[0])) = val
	return append(encodedBytes, buf[:8]...)
}
// appendFloat64 appends the 8-byte in-memory (native byte order)
// representation of val to encodedBytes and returns the extended slice.
//
// buf is scratch space: val is stored into its first 8 bytes through an
// unsafe pointer cast, so buf must be at least 8 bytes long.
func appendFloat64(encodedBytes, buf []byte, val float64) []byte {
	*(*float64)(unsafe.Pointer(&buf[0])) = val
	return append(encodedBytes, buf[:8]...)
}
// appendDecimal appends the raw in-memory bytes of the MyDecimal struct
// pointed to by val to encodedBytes and returns the extended slice.
//
// buf is scratch space: *val is copied to its start through an unsafe pointer
// cast and the first types.MyDecimalStructSize bytes are then appended, so
// buf must be at least that long.
func appendDecimal(encodedBytes, buf []byte, val *types.MyDecimal) []byte {
	*(*types.MyDecimal)(unsafe.Pointer(&buf[0])) = *val
	return append(encodedBytes, buf[:types.MyDecimalStructSize]...)
}
// writeTime serialises t into buf using a fixed layout:
//
//	bytes 0-1   year (big-endian uint16)
//	byte  2     month
//	byte  3     day
//	byte  4     hour
//	byte  5     minute
//	byte  6     second
//	bytes 8-11  microsecond (big-endian uint32)
//	byte  12    t.Type
//	byte  13    t.Fsp (truncated to uint8)
//
// Byte 7 and anything from byte 14 onwards are NOT written and keep whatever
// buf already held. buf must be at least 14 bytes long.
func writeTime(buf []byte, t types.Time) {
	core := t.Time
	binary.BigEndian.PutUint16(buf, uint16(core.Year()))
	buf[2] = uint8(core.Month())
	buf[3] = uint8(core.Day())
	buf[4] = uint8(core.Hour())
	buf[5] = uint8(core.Minute())
	buf[6] = uint8(core.Second())
	binary.BigEndian.PutUint32(buf[8:], uint32(core.Microsecond()))
	buf[12] = t.Type
	buf[13] = uint8(t.Fsp)
}
// appendTime appends a 16-byte encoding of val to encodedBytes and returns
// the extended slice.
//
// buf is scratch space of at least 16 bytes. It is shared with the other
// append helpers and reused across arguments and rows, so it may still hold
// bytes from an earlier int, decimal or duration encoding. writeTime fills
// only bytes 0-6 and 8-13; the padding bytes 7, 14 and 15 are therefore
// zeroed here. Without that, two equal time values could produce different
// keys depending on what was encoded before them, and a distinct count would
// treat them as different values.
func appendTime(encodedBytes, buf []byte, val types.Time) []byte {
	buf = buf[:16]
	// Clear the bytes writeTime does not touch so the key depends only on val.
	buf[7], buf[14], buf[15] = 0, 0, 0
	writeTime(buf, val)
	return append(encodedBytes, buf...)
}
// appendDuration appends 16 bytes holding the raw in-memory representation of
// val to encodedBytes and returns the extended slice.
//
// buf is scratch space: val is copied to its start through an unsafe pointer
// cast and the first 16 bytes are then appended, so buf must be at least 16
// bytes long.
// NOTE(review): the fixed length 16 presumably equals the size of
// types.Duration on 64-bit platforms — confirm against the type definition.
func appendDuration(encodedBytes, buf []byte, val types.Duration) []byte {
	*(*types.Duration)(unsafe.Pointer(&buf[0])) = val
	return append(encodedBytes, buf[:16]...)
}
// appendJSON appends the type code of val followed by its value bytes to
// encodedBytes and returns the extended slice. The scratch buffer parameter
// is unused.
func appendJSON(encodedBytes, _ []byte, val json.BinaryJSON) []byte {
	withCode := append(encodedBytes, val.TypeCode)
	return append(withCode, val.Value...)
}
// appendString appends val to encodedBytes as a uvarint length prefix
// followed by the string bytes, and returns the extended slice. The scratch
// buffer parameter is unused.
//
// The length prefix keeps multi-argument keys unambiguous. With plain
// concatenation the tuples ("a", "bc") and ("ab", "c") would both encode to
// "abc" and be treated as the same distinct value.
func appendString(encodedBytes, _ []byte, val string) []byte {
	var lenBuf [binary.MaxVarintLen64]byte
	n := binary.PutUvarint(lenBuf[:], uint64(len(val)))
	encodedBytes = append(encodedBytes, lenBuf[:n]...)
	return append(encodedBytes, val...)
}