// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package vecgroupchecker

import (
	"bytes"

	"github.com/pingcap/errors"
	"github.com/pingcap/tidb/pkg/expression"
	"github.com/pingcap/tidb/pkg/types"
	"github.com/pingcap/tidb/pkg/util/chunk"
	"github.com/pingcap/tidb/pkg/util/codec"
)

// VecGroupChecker is used to split a given chunk according to the `group by` expressions in a vectorized manner.
// It is usually used for stream aggregation (streamAgg).
type VecGroupChecker struct {
	ctx expression.EvalContext

	// set these functions for testing
	allocateBuffer func(evalType types.EvalType, capacity int) (*chunk.Column, error)
	releaseBuffer  func(buf *chunk.Column)

	// lastGroupKeyOfPrevChk is the groupKey of the last group of the previous chunk
	lastGroupKeyOfPrevChk []byte
	// firstGroupKey and lastGroupKey are used to store the groupKey of the first and last group of the current chunk
	firstGroupKey []byte
	lastGroupKey  []byte

	// firstRowDatums and lastRowDatums store the results of the expression evaluation
	// for the first and last rows of the current chunk in datum.
	// They are encoded to get firstGroupKey and lastGroupKey.
	firstRowDatums []types.Datum
	lastRowDatums  []types.Datum

	// sameGroup is used to check whether the current row belongs to the same group as the previous row
	sameGroup []bool

	// groupOffset holds the end offset (exclusive) of each group in the current chunk
	groupOffset []int

	GroupByItems []expression.Expression

	// nextGroupID records the group id of the next group to be consumed
	nextGroupID int
	// groupCount is the count of groups in the current chunk
	groupCount int

	// vecEnabled indicates whether to use vectorized evaluation or not.
	vecEnabled bool
}

// NewVecGroupChecker creates a new VecGroupChecker
func NewVecGroupChecker(ctx expression.EvalContext, vecEnabled bool, items []expression.Expression) *VecGroupChecker {
	return &VecGroupChecker{
		ctx:          ctx,
		vecEnabled:   vecEnabled,
		GroupByItems: items,
		groupCount:   0,
		nextGroupID:  0,
		sameGroup:    make([]bool, 1024),
	}
}

// SplitIntoGroups splits a chunk into multiple groups such that all rows in the same group share the same groupKey.
// `isFirstGroupSameAsPrev` indicates whether the groupKey of the first group of the newly passed chunk equals the groupKey of the last group of the previous chunk.
// TODO: Since all the group by items are only column references, guaranteed by building projection below aggregation, we can directly compare data in a chunk.
func (e *VecGroupChecker) SplitIntoGroups(chk *chunk.Chunk) (isFirstGroupSameAsPrev bool, err error) {
	// numRows cannot be zero: `fetchChild` is called before `SplitIntoGroups` is called.
	// If numRows == 0, it is returned in `fetchChild`. See `fetchChild` for more details.
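	// For illustration (example data only): with a single GROUP BY column whose values
	// in the chunk are [1, 1, 2, 2, 3], this function ends up with groupOffset = [2, 4, 5]
	// and groupCount = 3, i.e. the groups cover rows [0, 2), [2, 4) and [4, 5), which the
	// caller then consumes one by one through GetNextGroup.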
	numRows := chk.NumRows()
	e.Reset()
	e.nextGroupID = 0
	if len(e.GroupByItems) == 0 {
		e.groupOffset = append(e.groupOffset, numRows)
		e.groupCount = 1
		return true, nil
	}

	for _, item := range e.GroupByItems {
		err = e.getFirstAndLastRowDatum(item, chk, numRows)
		if err != nil {
			return false, err
		}
	}
	ec := e.ctx.ErrCtx()
	e.firstGroupKey, err = codec.EncodeKey(e.ctx.Location(), e.firstGroupKey, e.firstRowDatums...)
	err = ec.HandleError(err)
	if err != nil {
		return false, err
	}

	e.lastGroupKey, err = codec.EncodeKey(e.ctx.Location(), e.lastGroupKey, e.lastRowDatums...)
	err = ec.HandleError(err)
	if err != nil {
		return false, err
	}

	if len(e.lastGroupKeyOfPrevChk) == 0 {
		isFirstGroupSameAsPrev = false
	} else {
		if bytes.Equal(e.lastGroupKeyOfPrevChk, e.firstGroupKey) {
			isFirstGroupSameAsPrev = true
		} else {
			isFirstGroupSameAsPrev = false
		}
	}

	if length := len(e.lastGroupKey); len(e.lastGroupKeyOfPrevChk) >= length {
		e.lastGroupKeyOfPrevChk = e.lastGroupKeyOfPrevChk[:length]
	} else {
		e.lastGroupKeyOfPrevChk = make([]byte, length)
	}
	copy(e.lastGroupKeyOfPrevChk, e.lastGroupKey)

	if bytes.Equal(e.firstGroupKey, e.lastGroupKey) {
		e.groupOffset = append(e.groupOffset, numRows)
		e.groupCount = 1
		return isFirstGroupSameAsPrev, nil
	}

	if cap(e.sameGroup) < numRows {
		e.sameGroup = make([]bool, 0, numRows)
	}
	e.sameGroup = append(e.sameGroup, false)
	for i := 1; i < numRows; i++ {
		e.sameGroup = append(e.sameGroup, true)
	}

	for _, item := range e.GroupByItems {
		err = e.evalGroupItemsAndResolveGroups(item, e.vecEnabled, chk, numRows)
		if err != nil {
			return false, err
		}
	}

	for i := 1; i < numRows; i++ {
		if !e.sameGroup[i] {
			e.groupOffset = append(e.groupOffset, i)
		}
	}
	e.groupOffset = append(e.groupOffset, numRows)
	e.groupCount = len(e.groupOffset)
	return isFirstGroupSameAsPrev, nil
}

func (e *VecGroupChecker) getFirstAndLastRowDatum(
	item expression.Expression, chk *chunk.Chunk, numRows int) (err error) {
	var firstRowDatum, lastRowDatum types.Datum
	tp := item.GetType(e.ctx)
	eType := tp.EvalType()
	switch eType {
	case types.ETInt:
		firstRowVal, firstRowIsNull, err := item.EvalInt(e.ctx, chk.GetRow(0))
		if err != nil {
			return err
		}
		lastRowVal, lastRowIsNull, err := item.EvalInt(e.ctx, chk.GetRow(numRows-1))
		if err != nil {
			return err
		}
		if !firstRowIsNull {
			firstRowDatum.SetInt64(firstRowVal)
		} else {
			firstRowDatum.SetNull()
		}
		if !lastRowIsNull {
			lastRowDatum.SetInt64(lastRowVal)
		} else {
			lastRowDatum.SetNull()
		}
	case types.ETReal:
		firstRowVal, firstRowIsNull, err := item.EvalReal(e.ctx, chk.GetRow(0))
		if err != nil {
			return err
		}
		lastRowVal, lastRowIsNull, err := item.EvalReal(e.ctx, chk.GetRow(numRows-1))
		if err != nil {
			return err
		}
		if !firstRowIsNull {
			firstRowDatum.SetFloat64(firstRowVal)
		} else {
			firstRowDatum.SetNull()
		}
		if !lastRowIsNull {
			lastRowDatum.SetFloat64(lastRowVal)
		} else {
			lastRowDatum.SetNull()
		}
	case types.ETDecimal:
		firstRowVal, firstRowIsNull, err := item.EvalDecimal(e.ctx, chk.GetRow(0))
		if err != nil {
			return err
		}
		lastRowVal, lastRowIsNull, err := item.EvalDecimal(e.ctx, chk.GetRow(numRows-1))
		if err != nil {
			return err
		}
		if !firstRowIsNull {
			// make a copy to avoid DATA RACE
			firstDatum := types.MyDecimal{}
			err := firstDatum.FromString(firstRowVal.ToString())
			if err != nil {
				return err
			}
			firstRowDatum.SetMysqlDecimal(&firstDatum)
		} else {
			firstRowDatum.SetNull()
		}
		if !lastRowIsNull {
			// make a copy to avoid DATA RACE
			lastDatum := types.MyDecimal{}
			err := lastDatum.FromString(lastRowVal.ToString())
			if err != nil {
				return err
			}
			lastRowDatum.SetMysqlDecimal(&lastDatum)
		} else {
			lastRowDatum.SetNull()
		}
	case types.ETDatetime, types.ETTimestamp:
		firstRowVal, firstRowIsNull, err := item.EvalTime(e.ctx, chk.GetRow(0))
		if err != nil {
			return err
		}
		lastRowVal, lastRowIsNull, err := item.EvalTime(e.ctx, chk.GetRow(numRows-1))
		if err != nil {
			return err
		}
		if !firstRowIsNull {
			firstRowDatum.SetMysqlTime(firstRowVal)
		} else {
			firstRowDatum.SetNull()
		}
		if !lastRowIsNull {
			lastRowDatum.SetMysqlTime(lastRowVal)
		} else {
			lastRowDatum.SetNull()
		}
	case types.ETDuration:
		firstRowVal, firstRowIsNull, err := item.EvalDuration(e.ctx, chk.GetRow(0))
		if err != nil {
			return err
		}
		lastRowVal, lastRowIsNull, err := item.EvalDuration(e.ctx, chk.GetRow(numRows-1))
		if err != nil {
			return err
		}
		if !firstRowIsNull {
			firstRowDatum.SetMysqlDuration(firstRowVal)
		} else {
			firstRowDatum.SetNull()
		}
		if !lastRowIsNull {
			lastRowDatum.SetMysqlDuration(lastRowVal)
		} else {
			lastRowDatum.SetNull()
		}
	case types.ETJson:
		firstRowVal, firstRowIsNull, err := item.EvalJSON(e.ctx, chk.GetRow(0))
		if err != nil {
			return err
		}
		lastRowVal, lastRowIsNull, err := item.EvalJSON(e.ctx, chk.GetRow(numRows-1))
		if err != nil {
			return err
		}
		if !firstRowIsNull {
			// make a copy to avoid DATA RACE
			firstRowDatum.SetMysqlJSON(firstRowVal.Copy())
		} else {
			firstRowDatum.SetNull()
		}
		if !lastRowIsNull {
			// make a copy to avoid DATA RACE
			lastRowDatum.SetMysqlJSON(lastRowVal.Copy())
		} else {
			lastRowDatum.SetNull()
		}
	case types.ETVectorFloat32:
		firstRowVal, firstRowIsNull, err := item.EvalVectorFloat32(e.ctx, chk.GetRow(0))
		if err != nil {
			return err
		}
		lastRowVal, lastRowIsNull, err := item.EvalVectorFloat32(e.ctx, chk.GetRow(numRows-1))
		if err != nil {
			return err
		}
		if !firstRowIsNull {
			// make a copy to avoid DATA RACE
			firstRowDatum.SetVectorFloat32(firstRowVal.Clone())
		} else {
			firstRowDatum.SetNull()
		}
		if !lastRowIsNull {
			// make a copy to avoid DATA RACE
			lastRowDatum.SetVectorFloat32(lastRowVal.Clone())
		} else {
			lastRowDatum.SetNull()
		}
	case types.ETString:
		firstRowVal, firstRowIsNull, err := item.EvalString(e.ctx, chk.GetRow(0))
		if err != nil {
			return err
		}
		lastRowVal, lastRowIsNull, err := item.EvalString(e.ctx, chk.GetRow(numRows-1))
		if err != nil {
			return err
		}
		if !firstRowIsNull {
			// make a copy to avoid DATA RACE
			firstDatum := string([]byte(firstRowVal))
			firstRowDatum.SetString(firstDatum, tp.GetCollate())
		} else {
			firstRowDatum.SetNull()
		}
		if !lastRowIsNull {
			// make a copy to avoid DATA RACE
			lastDatum := string([]byte(lastRowVal))
			lastRowDatum.SetString(lastDatum, tp.GetCollate())
		} else {
			lastRowDatum.SetNull()
		}
	default:
		err = errors.Errorf("unsupported type %s during evaluation", eType)
		return err
	}

	e.firstRowDatums = append(e.firstRowDatums, firstRowDatum)
	e.lastRowDatums = append(e.lastRowDatums, lastRowDatum)
	return err
}

// evalGroupItemsAndResolveGroups evaluates the chunk according to the expression item.
// It then resolves the rows into groups according to the evaluation results.
func (e *VecGroupChecker) evalGroupItemsAndResolveGroups(
	item expression.Expression, vecEnabled bool, chk *chunk.Chunk, numRows int) (err error) {
	tp := item.GetType(e.ctx)
	eType := tp.EvalType()
	if e.allocateBuffer == nil {
		e.allocateBuffer = expression.GetColumn
	}
	if e.releaseBuffer == nil {
		e.releaseBuffer = expression.PutColumn
	}
	col, err := e.allocateBuffer(eType, numRows)
	if err != nil {
		return err
	}
	defer e.releaseBuffer(col)
	err = expression.EvalExpr(e.ctx, vecEnabled, item, eType, chk, col)
	if err != nil {
		return err
	}

	previousIsNull := col.IsNull(0)
	switch eType {
	case types.ETInt:
		vals := col.Int64s()
		for i := 1; i < numRows; i++ {
			isNull := col.IsNull(i)
			if e.sameGroup[i] {
				switch {
				case !previousIsNull && !isNull:
					if vals[i] != vals[i-1] {
						e.sameGroup[i] = false
					}
				case isNull != previousIsNull:
					e.sameGroup[i] = false
				}
			}
			previousIsNull = isNull
		}
	case types.ETReal:
		vals := col.Float64s()
		for i := 1; i < numRows; i++ {
			isNull := col.IsNull(i)
			if e.sameGroup[i] {
				switch {
				case !previousIsNull && !isNull:
					if vals[i] != vals[i-1] {
						e.sameGroup[i] = false
					}
				case isNull != previousIsNull:
					e.sameGroup[i] = false
				}
			}
			previousIsNull = isNull
		}
	case types.ETDecimal:
		vals := col.Decimals()
		for i := 1; i < numRows; i++ {
			isNull := col.IsNull(i)
			if e.sameGroup[i] {
				switch {
				case !previousIsNull && !isNull:
					if vals[i].Compare(&vals[i-1]) != 0 {
						e.sameGroup[i] = false
					}
				case isNull != previousIsNull:
					e.sameGroup[i] = false
				}
			}
			previousIsNull = isNull
		}
	case types.ETDatetime, types.ETTimestamp:
		vals := col.Times()
		for i := 1; i < numRows; i++ {
			isNull := col.IsNull(i)
			if e.sameGroup[i] {
				switch {
				case !previousIsNull && !isNull:
					if vals[i].Compare(vals[i-1]) != 0 {
						e.sameGroup[i] = false
					}
				case isNull != previousIsNull:
					e.sameGroup[i] = false
				}
			}
			previousIsNull = isNull
		}
	case types.ETDuration:
		vals := col.GoDurations()
		for i := 1; i < numRows; i++ {
			isNull := col.IsNull(i)
			if e.sameGroup[i] {
				switch {
				case !previousIsNull && !isNull:
					if vals[i] != vals[i-1] {
						e.sameGroup[i] = false
					}
				case isNull != previousIsNull:
					e.sameGroup[i] = false
				}
			}
			previousIsNull = isNull
		}
	case types.ETJson:
		var previousKey, key types.BinaryJSON
		if !previousIsNull {
			previousKey = col.GetJSON(0)
		}
		for i := 1; i < numRows; i++ {
			isNull := col.IsNull(i)
			if !isNull {
				key = col.GetJSON(i)
			}
			if e.sameGroup[i] {
				if isNull == previousIsNull {
					if !isNull && types.CompareBinaryJSON(previousKey, key) != 0 {
						e.sameGroup[i] = false
					}
				} else {
					e.sameGroup[i] = false
				}
			}
			if !isNull {
				previousKey = key
			}
			previousIsNull = isNull
		}
	case types.ETVectorFloat32:
		var previousKey, key types.VectorFloat32
		if !previousIsNull {
			previousKey = col.GetVectorFloat32(0)
		}
		for i := 1; i < numRows; i++ {
			isNull := col.IsNull(i)
			if !isNull {
				key = col.GetVectorFloat32(i)
			}
			if e.sameGroup[i] {
				if isNull == previousIsNull {
					if !isNull && previousKey.Compare(key) != 0 {
						e.sameGroup[i] = false
					}
				} else {
					e.sameGroup[i] = false
				}
			}
			if !isNull {
				previousKey = key
			}
			previousIsNull = isNull
		}
	case types.ETString:
		previousKey := codec.ConvertByCollationStr(col.GetString(0), tp)
		for i := 1; i < numRows; i++ {
			key := codec.ConvertByCollationStr(col.GetString(i), tp)
			isNull := col.IsNull(i)
			if e.sameGroup[i] {
				if isNull != previousIsNull || previousKey != key {
					e.sameGroup[i] = false
				}
			}
			previousKey = key
			previousIsNull = isNull
		}
	default:
		err = errors.Errorf("unsupported type %s during evaluation", eType)
	}
	return err
}
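
// To illustrate how evalGroupItemsAndResolveGroups refines sameGroup across several
// group by items (example data only): with GROUP BY (a, b) and chunk rows
// (1,1), (1,2), (1,2), (2,2), evaluating a leaves sameGroup = [false, true, true, false],
// and evaluating b then turns it into [false, false, true, false], so SplitIntoGroups
// derives groupOffset = [1, 3, 4], i.e. three groups covering rows [0,1), [1,3) and [3,4).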

// GetNextGroup returns the begin and end position of the next group.
func (e *VecGroupChecker) GetNextGroup() (begin, end int) {
	if e.nextGroupID == 0 {
		begin = 0
	} else {
		begin = e.groupOffset[e.nextGroupID-1]
	}

	end = e.groupOffset[e.nextGroupID]
	e.nextGroupID++
	return begin, end
}

// IsExhausted returns true if there is no more group to check.
func (e *VecGroupChecker) IsExhausted() bool {
	return e.nextGroupID >= e.groupCount
}

// Reset resets the group checker.
func (e *VecGroupChecker) Reset() {
	if e.groupOffset != nil {
		e.groupOffset = e.groupOffset[:0]
		e.groupCount = 0
	}
	if e.sameGroup != nil {
		e.sameGroup = e.sameGroup[:0]
	}
	if e.firstGroupKey != nil {
		e.firstGroupKey = e.firstGroupKey[:0]
	}
	if e.lastGroupKey != nil {
		e.lastGroupKey = e.lastGroupKey[:0]
	}
	if e.firstRowDatums != nil {
		e.firstRowDatums = e.firstRowDatums[:0]
	}
	if e.lastRowDatums != nil {
		e.lastRowDatums = e.lastRowDatums[:0]
	}
}

// GroupCount returns the number of groups.
func (e *VecGroupChecker) GroupCount() int {
	return e.groupCount
}
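
// Typical usage (an illustrative sketch, not taken from this file; the real caller is the
// stream aggregation executor, and `evalCtx`, `groupByItems` and `chk` below are assumed
// to be provided by that caller):
//
//	checker := vecgroupchecker.NewVecGroupChecker(evalCtx, true, groupByItems)
//	firstGroupSameAsPrev, err := checker.SplitIntoGroups(chk)
//	if err != nil {
//		return err
//	}
//	_ = firstGroupSameAsPrev // decides whether to merge with the previous chunk's last group
//	for !checker.IsExhausted() {
//		begin, end := checker.GetNextGroup()
//		for i := begin; i < end; i++ {
//			row := chk.GetRow(i) // all rows in [begin, end) share the same group key
//			_ = row
//		}
//	}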