Files
tidb/executor/merge_join.go

522 lines
13 KiB
Go

// Copyright 2017 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package executor
import (
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/plan"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util/types"
)
// MergeJoinExec implements the merge join algorithm.
// This operator assumes that the iterators of both sides
// provide rows in the order required by the join condition:
// 1. For an equal-join, the join keys from each side
// match the order given.
// 2. For other cases it's preferred not to use sort-merge join (SMJ);
// BuildMergeJoin returns an error when a join key is not a column
// or when the join type is not inner, left outer or right outer.
type MergeJoinExec struct {
// The left side is always the driver side. For a right outer join
// BuildMergeJoin swaps the two row blocks and join-key lists and sets flipSide.
// ctx is passed to expression.EvalBool when evaluating filters and is the
// source of stmtCtx.
ctx context.Context
// stmtCtx is filled in by prepare and handed to compareKeys.
stmtCtx *variable.StatementContext
// leftJoinKeys and rightJoinKeys are the join-key columns of the driver
// side and of the other side, compared pairwise by index.
leftJoinKeys []*expression.Column
rightJoinKeys []*expression.Column
// prepared is reset by Open and set once prepare has succeeded.
prepared bool
// leftFilter is evaluated on each driver-side row in computeCrossProduct;
// BuildMergeJoin only sets it for outer joins.
leftFilter []expression.Expression
// otherFilter is evaluated on each joined row before it is emitted.
otherFilter []expression.Expression
// schema is the value returned by Schema.
schema *expression.Schema
preserveLeft bool // To preserve left side of the relation as in left outer join
// cursor is the index in outputBuf of the next row returned by Next.
cursor int
// NOTE(review): defaultValues is not read or written anywhere in the
// visible part of this file - confirm whether it is still needed.
defaultValues []types.Datum
// defaultRightRow is joined with a preserved driver-side row that has no
// match on the other side; BuildMergeJoin sets it for outer joins only.
defaultRightRow *Row
// outputBuf holds the joined rows produced by the latest computeJoin call.
outputBuf []*Row
// leftRowBlock and rightRowBlock yield blocks of rows sharing the same join keys.
leftRowBlock *rowBlockIterator
rightRowBlock *rowBlockIterator
// leftRows and rightRows are the current blocks of each side; nil means
// that side is exhausted.
leftRows []*Row
rightRows []*Row
// desc makes computeJoin negate the key comparison result; it is set from
// the assumeSortedDesc argument of BuildMergeJoin.
desc bool
// flipSide swaps the arguments passed to makeJoinRow when composing output
// rows; it is set for right outer joins.
flipSide bool
}
// rowBufferSize is the capacity used for the row cache of each
// rowBlockIterator and for the join output buffer.
const rowBufferSize = 4096
// joinBuilder collects the inputs of a join executor through chained setter
// calls; BuildMergeJoin turns them into a MergeJoinExec.
type joinBuilder struct {
// context is handed to the executor and to both row block iterators.
context context.Context
// leftChild and rightChild are the executors that supply the input rows.
leftChild Executor
rightChild Executor
// eqConditions are the equal conditions the join keys are extracted from.
eqConditions []*expression.ScalarFunction
// leftFilter and rightFilter are installed as the filters of the left and
// right row block iterators.
leftFilter []expression.Expression
rightFilter []expression.Expression
// otherFilter is evaluated by the executor on joined rows.
otherFilter []expression.Expression
// schema is the schema reported by the built executor.
schema *expression.Schema
// joinType selects inner, left outer or right outer behavior.
joinType plan.JoinType
// defaultValues becomes the data of the default row used by outer joins.
defaultValues []types.Datum
}
// Context stores the given context on the builder and returns the builder so
// that calls can be chained.
// NOTE(review): the parameter name shadows the imported context package
// inside this method.
func (b *joinBuilder) Context(context context.Context) *joinBuilder {
b.context = context
return b
}
// EqualConditions stores the equal conditions from which BuildMergeJoin
// extracts the join keys, and returns the builder for chaining.
func (b *joinBuilder) EqualConditions(conds []*expression.ScalarFunction) *joinBuilder {
b.eqConditions = conds
return b
}
// LeftChild stores the executor that supplies the left-side rows and returns
// the builder for chaining.
func (b *joinBuilder) LeftChild(exec Executor) *joinBuilder {
b.leftChild = exec
return b
}
// RightChild stores the executor that supplies the right-side rows and
// returns the builder for chaining.
func (b *joinBuilder) RightChild(exec Executor) *joinBuilder {
b.rightChild = exec
return b
}
// LeftFilter stores the filter expressions for the left side and returns the
// builder for chaining.
func (b *joinBuilder) LeftFilter(expr []expression.Expression) *joinBuilder {
b.leftFilter = expr
return b
}
// RightFilter stores the filter expressions for the right side and returns
// the builder for chaining.
func (b *joinBuilder) RightFilter(expr []expression.Expression) *joinBuilder {
b.rightFilter = expr
return b
}
// OtherFilter stores the filter expressions that the executor evaluates on
// joined rows, and returns the builder for chaining.
func (b *joinBuilder) OtherFilter(expr []expression.Expression) *joinBuilder {
b.otherFilter = expr
return b
}
// Schema stores the schema the built executor will report and returns the
// builder for chaining.
func (b *joinBuilder) Schema(schema *expression.Schema) *joinBuilder {
b.schema = schema
return b
}
// JoinType stores the join type that BuildMergeJoin switches on and returns
// the builder for chaining.
func (b *joinBuilder) JoinType(joinType plan.JoinType) *joinBuilder {
b.joinType = joinType
return b
}
// DefaultVals stores the values that BuildMergeJoin wraps into the default
// row of an outer join, and returns the builder for chaining.
func (b *joinBuilder) DefaultVals(defaultValues []types.Datum) *joinBuilder {
b.defaultValues = defaultValues
return b
}
// BuildMergeJoin assembles a MergeJoinExec from the values collected by the
// builder.
//
// Every equal condition must have exactly two arguments and both must be
// columns; the first becomes a left join key and the second a right join key.
// assumeSortedDesc tells the executor that both inputs are sorted in
// descending key order.
//
// It returns an error annotating ErrBuildExecutor when an equal condition is
// malformed or when the join type is not inner, left outer or right outer.
func (b *joinBuilder) BuildMergeJoin(assumeSortedDesc bool) (*MergeJoinExec, error) {
	var leftJoinKeys, rightJoinKeys []*expression.Column
	for _, eqCond := range b.eqConditions {
		args := eqCond.GetArgs()
		if len(args) != 2 {
			return nil, errors.Annotate(ErrBuildExecutor, "invalid join key for equal condition")
		}
		lKey, ok := args[0].(*expression.Column)
		if !ok {
			return nil, errors.Annotate(ErrBuildExecutor, "left side of join key must be column for merge join")
		}
		rKey, ok := args[1].(*expression.Column)
		if !ok {
			return nil, errors.Annotate(ErrBuildExecutor, "right side of join key must be column for merge join")
		}
		leftJoinKeys = append(leftJoinKeys, lKey)
		rightJoinKeys = append(rightJoinKeys, rKey)
	}

	leftRowBlock := &rowBlockIterator{
		ctx:      b.context,
		reader:   b.leftChild,
		filter:   b.leftFilter,
		joinKeys: leftJoinKeys,
	}
	rightRowBlock := &rowBlockIterator{
		ctx:      b.context,
		reader:   b.rightChild,
		filter:   b.rightFilter,
		joinKeys: rightJoinKeys,
	}

	exec := &MergeJoinExec{
		ctx:           b.context,
		leftJoinKeys:  leftJoinKeys,
		rightJoinKeys: rightJoinKeys,
		leftRowBlock:  leftRowBlock,
		rightRowBlock: rightRowBlock,
		otherFilter:   b.otherFilter,
		schema:        b.schema,
		desc:          assumeSortedDesc,
	}

	switch b.joinType {
	case plan.LeftOuterJoin:
		// The preserved side must not lose rows while being read, so its
		// filter is moved from the row block to the executor, which applies
		// it when computing the cross product.
		exec.leftRowBlock.filter = nil
		exec.leftFilter = b.leftFilter
		exec.preserveLeft = true
		exec.defaultRightRow = &Row{Data: b.defaultValues}
	case plan.RightOuterJoin:
		// A right outer join runs as a left outer join with the sides
		// swapped: the right child becomes the driver side.
		exec.leftRowBlock = rightRowBlock
		exec.rightRowBlock = leftRowBlock
		exec.leftRowBlock.filter = nil
		// The driver rows now come from the right child, so the deferred
		// driver-side filter has to be the right filter. Using the left
		// filter here would drop the right filter entirely and evaluate
		// left-child conditions against right-child rows.
		exec.leftFilter = b.rightFilter
		exec.preserveLeft = true
		exec.defaultRightRow = &Row{Data: b.defaultValues}
		exec.flipSide = true
		exec.leftJoinKeys = rightJoinKeys
		exec.rightJoinKeys = leftJoinKeys
	case plan.InnerJoin:
		// Nothing to adjust: both row blocks keep their own filters.
	default:
		return nil, errors.Annotate(ErrBuildExecutor, "unknown join type")
	}
	return exec, nil
}
// rowBlockIterator reads rows from an executor and groups consecutive rows
// that compare equal on the join keys into blocks.
type rowBlockIterator struct {
// stmtCtx is filled in by init and handed to compareKeys.
stmtCtx *variable.StatementContext
// ctx is passed to expression.EvalBool and is the source of stmtCtx.
ctx context.Context
// reader supplies the input rows.
reader Executor
// filter, when non-nil, drops rows for which it does not evaluate to true.
filter []expression.Expression
// joinKeys are the columns that decide whether two rows share a block.
joinKeys []*expression.Column
// peekedRow is the first row of the next block, read ahead of time;
// nil means the input is exhausted.
peekedRow *Row
// rowCache is the buffer that nextBlock slices to build each returned block.
rowCache []*Row
}
// init checks that the iterator has a reader, at least one join key and a
// context, then loads the statement context, reads ahead the first row that
// passes the filter and allocates the row cache.
// It returns an error when a required field is missing or when reading the
// first row fails.
func (rb *rowBlockIterator) init() error {
	// len of a nil slice is 0, so this also covers a nil joinKeys.
	incomplete := rb.reader == nil || len(rb.joinKeys) == 0 || rb.ctx == nil
	if incomplete {
		return errors.Errorf("Invalid arguments: Empty arguments detected.")
	}

	rb.stmtCtx = rb.ctx.GetSessionVars().StmtCtx

	first, err := rb.nextRow()
	rb.peekedRow = first
	if err != nil {
		return errors.Trace(err)
	}

	rb.rowCache = make([]*Row, 0, rowBufferSize)
	return nil
}
// nextRow returns the next row from the reader that passes the filter.
// Rows rejected by the filter are skipped. It returns (nil, nil) once the
// reader is exhausted, and a traced error when reading or filtering fails.
func (rb *rowBlockIterator) nextRow() (*Row, error) {
	for {
		candidate, readErr := rb.reader.Next()
		switch {
		case readErr != nil:
			return nil, errors.Trace(readErr)
		case candidate == nil:
			// End of input.
			return nil, nil
		case rb.filter == nil:
			// No filter configured: every row is accepted.
			return candidate, nil
		}

		accepted, evalErr := expression.EvalBool(rb.filter, candidate.Data, rb.ctx)
		if evalErr != nil {
			return nil, errors.Trace(evalErr)
		}
		if accepted {
			return candidate, nil
		}
	}
}
// nextBlock returns the next run of consecutive rows whose join keys compare
// equal, starting with the row that was read ahead previously.
// It returns (nil, nil) when no row is left. The row that ends the block is
// kept as the new read-ahead row for the following call.
// The returned slice is built on the iterator's row cache, so it is
// overwritten by the next call to nextBlock.
func (rb *rowBlockIterator) nextBlock() ([]*Row, error) {
	head := rb.peekedRow
	if head == nil {
		return nil, nil
	}

	block := append(rb.rowCache[0:0:rowBufferSize], head)
	for {
		row, err := rb.nextRow()
		if err != nil {
			return nil, errors.Trace(err)
		}
		if row == nil {
			// Input exhausted: this is the final block.
			rb.peekedRow = nil
			return block, nil
		}

		order, err := compareKeys(rb.stmtCtx, row, rb.joinKeys, head, rb.joinKeys)
		if err != nil {
			return nil, errors.Trace(err)
		}
		if order != 0 {
			// The keys changed: remember the row as the head of the next block.
			rb.peekedRow = row
			return block, nil
		}
		block = append(block, row)
	}
}
// Close implements the Executor Close interface.
// It drops the output buffer and closes both child readers. Both readers are
// always closed; when both fail, only the error of the left row block's
// reader is reported.
func (e *MergeJoinExec) Close() error {
	e.outputBuf = nil

	leftErr := e.leftRowBlock.reader.Close()
	rightErr := e.rightRowBlock.reader.Close()

	switch {
	case leftErr != nil:
		return errors.Trace(leftErr)
	case rightErr != nil:
		return errors.Trace(rightErr)
	default:
		return nil
	}
}
// Open implements the Executor Open interface.
// It resets the execution state so that the next call to Next prepares the
// join again, then opens the reader of the left row block followed by the
// reader of the right row block. The second reader is not opened when the
// first one fails.
func (e *MergeJoinExec) Open() error {
	e.prepared, e.cursor, e.outputBuf = false, 0, nil

	if openErr := e.leftRowBlock.reader.Open(); openErr != nil {
		return errors.Trace(openErr)
	}
	openErr := e.rightRowBlock.reader.Open()
	return errors.Trace(openErr)
}
// Schema implements the Executor Schema interface.
// It returns the schema stored on the executor.
func (e *MergeJoinExec) Schema() *expression.Schema {
return e.schema
}
// compareKeys compares two rows key by key: the i-th left key evaluated on
// leftRow against the i-th right key evaluated on rightRow.
// It returns the result of the first comparison that is not zero, or zero
// when all keys compare equal. Evaluation and comparison errors are returned
// traced, with a zero result.
// rightKeys is indexed by the positions of leftKeys, so it must be at least
// as long.
func compareKeys(stmtCtx *variable.StatementContext,
	leftRow *Row, leftKeys []*expression.Column,
	rightRow *Row, rightKeys []*expression.Column) (int, error) {
	for idx := range leftKeys {
		lhs, err := leftKeys[idx].Eval(leftRow.Data)
		if err != nil {
			return 0, errors.Trace(err)
		}
		rhs, err := rightKeys[idx].Eval(rightRow.Data)
		if err != nil {
			return 0, errors.Trace(err)
		}
		order, err := lhs.CompareDatum(stmtCtx, rhs)
		if err != nil {
			return 0, errors.Trace(err)
		}
		if order == 0 {
			// Tie on this key: let the next key decide.
			continue
		}
		return order, nil
	}
	return 0, nil
}
// outputJoinRow joins the two rows and appends the result to the output
// buffer without applying any filter. When flipSide is set the rows are
// handed to makeJoinRow in swapped order.
func (e *MergeJoinExec) outputJoinRow(leftRow *Row, rightRow *Row) {
	first, second := leftRow, rightRow
	if e.flipSide {
		first, second = rightRow, leftRow
	}
	e.outputBuf = append(e.outputBuf, makeJoinRow(first, second))
}
// outputFilteredJoinRow joins the two rows and appends the result to the
// output buffer only if it satisfies otherFilter (or no such filter is set).
// When flipSide is set the rows are handed to makeJoinRow in swapped order.
// It returns a traced error when evaluating the filter fails.
func (e *MergeJoinExec) outputFilteredJoinRow(leftRow *Row, rightRow *Row) error {
	first, second := leftRow, rightRow
	if e.flipSide {
		first, second = rightRow, leftRow
	}
	joined := makeJoinRow(first, second)

	keep := true
	if e.otherFilter != nil {
		var err error
		keep, err = expression.EvalBool(e.otherFilter, joined.Data, e.ctx)
		if err != nil {
			return errors.Trace(err)
		}
	}
	if keep {
		e.outputBuf = append(e.outputBuf, joined)
	}
	return nil
}
// tryOutputLeftRows emits every row of the current left block joined with the
// default right row, but only when the left side is preserved; otherwise it
// does nothing. It always returns nil.
func (e *MergeJoinExec) tryOutputLeftRows() error {
	if !e.preserveLeft {
		return nil
	}
	for i := range e.leftRows {
		e.outputJoinRow(e.leftRows[i], e.defaultRightRow)
	}
	return nil
}
// computeCrossProduct joins every row of the current left block with every
// row of the current right block; both blocks are expected to share the same
// join keys.
// A left row rejected by leftFilter is not joined with any right row. When
// the left side is preserved, a left row that was rejected or that produced
// no output row is emitted once with the default right row.
// It returns a traced error when evaluating a filter fails.
func (e *MergeJoinExec) computeCrossProduct() error {
	for _, driver := range e.leftRows {
		// The single-side filter was skipped while reading a preserved side,
		// so it is applied here instead.
		pass := true
		if e.leftFilter != nil {
			var err error
			pass, err = expression.EvalBool(e.leftFilter, driver.Data, e.ctx)
			if err != nil {
				return errors.Trace(err)
			}
		}
		if !pass {
			if e.preserveLeft {
				e.outputJoinRow(driver, e.defaultRightRow)
			}
			continue
		}

		before := len(e.outputBuf)
		for _, inner := range e.rightRows {
			if err := e.outputFilteredJoinRow(driver, inner); err != nil {
				return errors.Trace(err)
			}
		}
		// Nothing survived otherFilter, but the row must still appear.
		if e.preserveLeft && len(e.outputBuf) == before {
			e.outputJoinRow(driver, e.defaultRightRow)
		}
	}
	return nil
}
// computeJoin empties the output buffer and advances the two sides block by
// block until at least one joined row has been produced or the join is
// finished.
// It returns true when the output buffer holds new rows, false when no more
// rows can be produced, and a traced error when comparing keys, filtering or
// reading a block fails.
func (e *MergeJoinExec) computeJoin() (bool, error) {
	e.outputBuf = e.outputBuf[0:0:rowBufferSize]
	for {
		var (
			order int
			err   error
		)
		switch {
		case e.leftRows != nil && e.rightRows != nil:
			// Every row in a block shares the same keys, so comparing the
			// first rows is enough.
			order, err = compareKeys(e.stmtCtx, e.leftRows[0], e.leftJoinKeys, e.rightRows[0], e.rightJoinKeys)
			if err != nil {
				return false, errors.Trace(err)
			}
			if e.desc {
				order = -order
			}
		case e.leftRows != nil && e.preserveLeft:
			// The right side is exhausted but preserved left rows remain;
			// treat the left block as smaller so it gets flushed below.
			order = -1
		default:
			// The left side is exhausted, or the right side is exhausted and
			// the left side is not preserved.
			return false, nil
		}

		produced := len(e.outputBuf)
		switch {
		case order > 0:
			// The right block is behind: skip it. Nothing is emitted here.
			e.rightRows, err = e.rightRowBlock.nextBlock()
			if err != nil {
				return false, errors.Trace(err)
			}
			continue
		case order < 0:
			// The left block has no partner: emit it if preserved, then move on.
			if err = e.tryOutputLeftRows(); err != nil {
				return false, errors.Trace(err)
			}
			e.leftRows, err = e.leftRowBlock.nextBlock()
			if err != nil {
				return false, errors.Trace(err)
			}
		default:
			// Keys match: join the two blocks, then advance both sides.
			if err = e.computeCrossProduct(); err != nil {
				return false, errors.Trace(err)
			}
			e.leftRows, err = e.leftRowBlock.nextBlock()
			if err != nil {
				return false, errors.Trace(err)
			}
			e.rightRows, err = e.rightRowBlock.nextBlock()
			if err != nil {
				return false, errors.Trace(err)
			}
		}

		if len(e.outputBuf) > produced {
			return true, nil
		}
	}
}
// prepare loads the statement context, initializes both row block iterators
// (left first), allocates the output buffer and fetches the first block of
// each side. The executor is marked as prepared only when every step
// succeeds; otherwise the first error is returned traced.
func (e *MergeJoinExec) prepare() error {
	e.stmtCtx = e.ctx.GetSessionVars().StmtCtx

	for _, side := range []*rowBlockIterator{e.leftRowBlock, e.rightRowBlock} {
		if err := side.init(); err != nil {
			return errors.Trace(err)
		}
	}

	e.outputBuf = make([]*Row, 0, rowBufferSize)

	var err error
	if e.leftRows, err = e.leftRowBlock.nextBlock(); err != nil {
		return errors.Trace(err)
	}
	if e.rightRows, err = e.rightRowBlock.nextBlock(); err != nil {
		return errors.Trace(err)
	}

	e.prepared = true
	return nil
}
// Next implements the Executor Next interface.
// It prepares the join on first use, refills the output buffer once all
// buffered rows have been handed out, and returns the next buffered row.
// It returns (nil, nil) when the join has no more rows.
func (e *MergeJoinExec) Next() (*Row, error) {
	if !e.prepared {
		if err := e.prepare(); err != nil {
			return nil, errors.Trace(err)
		}
	}

	if e.cursor >= len(e.outputBuf) {
		produced, err := e.computeJoin()
		if err != nil {
			return nil, errors.Trace(err)
		}
		if !produced {
			return nil, nil
		}
		e.cursor = 0
	}

	next := e.outputBuf[e.cursor]
	e.cursor++
	return next, nil
}