[feature-wip](nereids) Implement using join (#15311)

This commit is contained in:
Kikyou1997
2022-12-28 19:22:20 +08:00
committed by GitHub
parent 75aa00d3d0
commit 95e6553d90
8 changed files with 382 additions and 74 deletions

View File

@ -215,22 +215,6 @@ public class CascadesContext {
return this;
}
/**
 * Registers a table with this context by appending it to {@code tables}.
 *
 * @param table the table to remember
 */
public void addToTable(Table table) {
    this.tables.add(table);
}
/**
 * Takes the read lock of every table held in {@code tables}, in iteration order.
 */
public void lockTableOnRead() {
    tables.forEach(table -> table.readLock());
}
/**
 * Releases the read lock of every table held in {@code tables}, in iteration order.
 */
public void releaseTableReadLock() {
    tables.forEach(table -> table.readUnlock());
}
/**
* Extract tables.
*/

View File

@ -199,6 +199,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.logical.RelationUtil;
import org.apache.doris.nereids.trees.plans.logical.UsingJoin;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.TinyIntType;
@ -1249,28 +1250,6 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
} else {
joinType = JoinType.CROSS_JOIN;
}
// TODO: natural join, lateral join, using join, union join
JoinCriteriaContext joinCriteria = join.joinCriteria();
Optional<Expression> condition = Optional.empty();
if (joinCriteria != null) {
if (joinCriteria.booleanExpression() != null) {
condition = Optional.ofNullable(getExpression(joinCriteria.booleanExpression()));
}
if (joinCriteria.USING() != null) {
List<UnboundSlot> ids =
visitIdentifierList(joinCriteria.identifierList())
.stream().map(UnboundSlot::quoted).collect(
Collectors.toList());
return new LogicalJoin(JoinType.USING_JOIN, ids, last, plan(join.relationPrimary()));
}
} else {
// keep same with original planner, allow cross/inner join
if (!joinType.isInnerOrCrossJoin()) {
throw new ParseException("on mustn't be empty except for cross/inner join", join);
}
}
JoinHint joinHint = Optional.ofNullable(join.joinHint()).map(hintCtx -> {
String hint = typedVisit(join.joinHint());
if (JoinHint.JoinHintType.SHUFFLE.toString().equalsIgnoreCase(hint)) {
@ -1281,13 +1260,35 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
throw new ParseException("Invalid join hint: " + hint, hintCtx);
}
}).orElse(JoinHint.NONE);
last = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION,
condition.map(ExpressionUtils::extractConjunction)
.orElse(ExpressionUtils.EMPTY_CONDITION),
joinHint,
last,
plan(join.relationPrimary()));
// TODO: natural join, lateral join, union join (using join is handled below)
JoinCriteriaContext joinCriteria = join.joinCriteria();
Optional<Expression> condition = Optional.empty();
List<UnboundSlot> ids = null;
if (joinCriteria != null) {
if (joinCriteria.booleanExpression() != null) {
condition = Optional.ofNullable(getExpression(joinCriteria.booleanExpression()));
} else if (joinCriteria.USING() != null) {
ids = visitIdentifierList(joinCriteria.identifierList())
.stream().map(UnboundSlot::quoted).collect(
Collectors.toList());
}
} else {
// keep same with original planner, allow cross/inner join
if (!joinType.isInnerOrCrossJoin()) {
throw new ParseException("on mustn't be empty except for cross/inner join", join);
}
}
if (ids == null) {
last = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION,
condition.map(ExpressionUtils::extractConjunction)
.orElse(ExpressionUtils.EMPTY_CONDITION),
joinHint,
last,
plan(join.relationPrimary()));
} else {
last = new UsingJoin(joinType, last,
plan(join.relationPrimary()), Collections.emptyList(), ids, joinHint);
}
}
return last;
}

View File

@ -66,6 +66,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.UsingJoin;
import org.apache.doris.planner.PlannerContext;
import com.google.common.base.Preconditions;
@ -76,6 +77,7 @@ import org.apache.commons.lang.StringUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@ -127,41 +129,63 @@ public class BindSlotReference implements AnalysisRuleFactory {
return new LogicalFilter<>(boundConjuncts, filter.child());
})
),
RuleType.BINDING_JOIN_SLOT.build(
logicalJoin().when(Plan::canBind)
.whenNot(j -> j.getJoinType().equals(JoinType.USING_JOIN)).thenApply(ctx -> {
LogicalJoin<GroupPlan, GroupPlan> join = ctx.root;
List<Expression> cond = join.getOtherJoinConjuncts().stream()
.map(expr -> bind(expr, join.children(), join, ctx.cascadesContext))
.collect(Collectors.toList());
List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts().stream()
.map(expr -> bind(expr, join.children(), join, ctx.cascadesContext))
.collect(Collectors.toList());
return new LogicalJoin<>(join.getJoinType(),
hashJoinConjuncts, cond, join.getHint(), join.left(), join.right());
})
),
RuleType.BINDING_USING_JOIN_SLOT.build(
logicalJoin().when(j -> j.getJoinType().equals(JoinType.USING_JOIN)).thenApply(ctx -> {
LogicalJoin<GroupPlan, GroupPlan> join = ctx.root;
List<Expression> unboundSlots = join.getHashJoinConjuncts();
List<Expression> leftSlots = unboundSlots.stream()
.map(expr -> bind(expr, Collections.singletonList(join.left()),
join, ctx.cascadesContext))
.collect(Collectors.toList());
List<Expression> rightSlots = unboundSlots.stream()
.map(expr -> bind(expr, Collections.singletonList(join.right()),
join, ctx.cascadesContext))
.collect(Collectors.toList());
usingJoin().thenApply(ctx -> {
UsingJoin<GroupPlan, GroupPlan> using = ctx.root;
LogicalJoin lj = new LogicalJoin(using.getJoinType() == JoinType.CROSS_JOIN
? JoinType.INNER_JOIN : using.getJoinType(),
using.getHashJoinConjuncts(),
using.getOtherJoinConjuncts(), using.getHint(), using.left(),
using.right());
List<Expression> unboundSlots = lj.getHashJoinConjuncts();
Set<String> slotNames = new HashSet<>();
List<Slot> leftOutput = new ArrayList<>(lj.left().getOutput());
// Suppose A JOIN B USING(name) JOIN C USING(name): [A JOIN B] is the left node. In this case,
// C should be joined with table B on C.name = B.name, so we reverse the left output to make sure
// that, for each column name, the rightmost slot is the one that gets matched.
Collections.reverse(leftOutput);
List<Expression> leftSlots = new ArrayList<>();
Scope scope = toScope(leftOutput.stream()
.filter(s -> !slotNames.contains(s.getName()))
.peek(s -> slotNames.add(s.getName())).collect(
Collectors.toList()));
for (Expression unboundSlot : unboundSlots) {
Expression expression = new SlotBinder(scope, lj, ctx.cascadesContext).bind(unboundSlot);
leftSlots.add(expression);
}
slotNames.clear();
scope = toScope(lj.right().getOutput().stream()
.filter(s -> !slotNames.contains(s.getName()))
.peek(s -> slotNames.add(s.getName())).collect(
Collectors.toList()));
List<Expression> rightSlots = new ArrayList<>();
for (Expression unboundSlot : unboundSlots) {
Expression expression = new SlotBinder(scope, lj, ctx.cascadesContext).bind(unboundSlot);
rightSlots.add(expression);
}
int size = leftSlots.size();
List<Expression> hashEqExpr = new ArrayList<>();
for (int i = 0; i < size; i++) {
hashEqExpr.add(new EqualTo(leftSlots.get(i), rightSlots.get(i)));
}
return new LogicalJoin(JoinType.INNER_JOIN, hashEqExpr,
join.getOtherJoinConjuncts(), join.getHint(), join.left(), join.right());
return lj.withHashJoinConjuncts(hashEqExpr);
})
),
RuleType.BINDING_JOIN_SLOT.build(
logicalJoin().when(Plan::canBind)
.whenNot(j -> j.getJoinType().equals(JoinType.USING_JOIN)).thenApply(ctx -> {
LogicalJoin<GroupPlan, GroupPlan> join = ctx.root;
List<Expression> cond = join.getOtherJoinConjuncts().stream()
.map(expr -> bind(expr, join.children(), join, ctx.cascadesContext))
.collect(Collectors.toList());
List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts().stream()
.map(expr -> bind(expr, join.children(), join, ctx.cascadesContext))
.collect(Collectors.toList());
return new LogicalJoin<>(join.getJoinType(),
hashJoinConjuncts, cond, join.getHint(), join.left(), join.right());
})
),
RuleType.BINDING_AGGREGATE_SLOT.build(
logicalAggregate().when(Plan::canBind).thenApply(ctx -> {
LogicalAggregate<GroupPlan> agg = ctx.root;

View File

@ -52,6 +52,7 @@ public enum PlanType {
LOGICAL_UNION,
LOGICAL_EXCEPT,
LOGICAL_INTERSECT,
LOGICAL_USING_JOIN,
GROUP_PLAN,
// physical plan

View File

@ -0,0 +1,169 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.plans.logical;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.algebra.Join;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import org.apache.commons.collections.CollectionUtils;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
/**
* select col1 from t1 join t2 using(col1);
*/
/**
 * Logical plan node for a join whose condition is written with a {@code USING} clause, e.g.
 * {@code select col1 from t1 join t2 using(col1);}.
 *
 * <p>The node carries the join type, two conjunct lists and the join hint. Both conjunct lists
 * are copied into immutable lists on construction.
 */
public class UsingJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends Plan>
        extends LogicalBinary<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> implements Join {

    private final JoinType joinType;
    private final ImmutableList<Expression> otherJoinConjuncts;
    private final ImmutableList<Expression> hashJoinConjuncts;
    private final JoinHint hint;

    /**
     * Creates a using join with no group expression and no precomputed logical properties.
     *
     * @param joinType the kind of join
     * @param leftChild the left input plan
     * @param rightChild the right input plan
     * @param expressions the other (non-hash) join conjuncts
     * @param hashJoinConjuncts the hash join conjuncts
     * @param hint the join hint
     */
    public UsingJoin(JoinType joinType, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild,
            List<Expression> expressions, List<Expression> hashJoinConjuncts,
            JoinHint hint) {
        this(joinType, leftChild, rightChild, expressions,
                hashJoinConjuncts, Optional.empty(), Optional.empty(), hint);
    }

    /**
     * Creates a using join.
     *
     * @param joinType the kind of join
     * @param leftChild the left input plan
     * @param rightChild the right input plan
     * @param expressions the other (non-hash) join conjuncts; copied
     * @param hashJoinConjuncts the hash join conjuncts; copied
     * @param groupExpression the group expression passed through to the super class
     * @param logicalProperties the logical properties passed through to the super class
     * @param hint the join hint
     */
    public UsingJoin(JoinType joinType, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild,
            List<Expression> expressions, List<Expression> hashJoinConjuncts,
            Optional<GroupExpression> groupExpression,
            Optional<LogicalProperties> logicalProperties,
            JoinHint hint) {
        super(PlanType.LOGICAL_USING_JOIN, groupExpression, logicalProperties, leftChild, rightChild);
        this.joinType = joinType;
        this.otherJoinConjuncts = ImmutableList.copyOf(expressions);
        this.hashJoinConjuncts = ImmutableList.copyOf(hashJoinConjuncts);
        this.hint = hint;
    }

    /**
     * Computes the output slots from the children's outputs according to the join type:
     * semi/anti joins expose only one side, outer joins mark the slots of the side that may be
     * padded as nullable, and every other join type concatenates left and right unchanged.
     */
    @Override
    public List<Slot> computeOutput() {
        switch (joinType) {
            case LEFT_SEMI_JOIN:
            case LEFT_ANTI_JOIN:
                return ImmutableList.copyOf(left().getOutput());
            case RIGHT_SEMI_JOIN:
            case RIGHT_ANTI_JOIN:
                return ImmutableList.copyOf(right().getOutput());
            case LEFT_OUTER_JOIN:
                return concat(left().getOutput(), asNullable(right().getOutput()));
            case RIGHT_OUTER_JOIN:
                return concat(asNullable(left().getOutput()), right().getOutput());
            case FULL_OUTER_JOIN:
                return concat(asNullable(left().getOutput()), asNullable(right().getOutput()));
            default:
                return concat(left().getOutput(), right().getOutput());
        }
    }

    /** Returns copies of the given slots with nullable set to true. */
    private static List<Slot> asNullable(List<Slot> slots) {
        return slots.stream().map(o -> o.withNullable(true))
                .collect(Collectors.toList());
    }

    /** Returns an immutable list holding {@code first} followed by {@code second}. */
    private static List<Slot> concat(List<Slot> first, List<Slot> second) {
        return ImmutableList.<Slot>builder()
                .addAll(first)
                .addAll(second)
                .build();
    }

    @Override
    public Plan withGroupExpression(Optional<GroupExpression> groupExpression) {
        return new UsingJoin<>(joinType, child(0), child(1), otherJoinConjuncts,
                hashJoinConjuncts, groupExpression, Optional.of(getLogicalProperties()), hint);
    }

    @Override
    public Plan withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
        return new UsingJoin<>(joinType, child(0), child(1), otherJoinConjuncts,
                hashJoinConjuncts, groupExpression, logicalProperties, hint);
    }

    @Override
    public Plan withChildren(List<Plan> children) {
        return new UsingJoin<>(joinType, children.get(0), children.get(1), otherJoinConjuncts,
                hashJoinConjuncts, groupExpression, Optional.of(getLogicalProperties()), hint);
    }

    @Override
    public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
        return visitor.visit(this, context);
    }

    /** Returns the hash join conjuncts followed by the other join conjuncts. */
    @Override
    public List<? extends Expression> getExpressions() {
        return new Builder<Expression>()
                .addAll(hashJoinConjuncts)
                .addAll(otherJoinConjuncts)
                .build();
    }

    public JoinType getJoinType() {
        return joinType;
    }

    public List<Expression> getOtherJoinConjuncts() {
        return otherJoinConjuncts;
    }

    public List<Expression> getHashJoinConjuncts() {
        return hashJoinConjuncts;
    }

    public JoinHint getHint() {
        return hint;
    }

    /** Combines the hash join conjuncts and the other join conjuncts into one optional condition. */
    @Override
    public Optional<Expression> getOnClauseCondition() {
        return ExpressionUtils.optionalAnd(hashJoinConjuncts, otherJoinConjuncts);
    }

    /**
     * Returns true only when a real hint is attached. {@link JoinHint#NONE} is the placeholder
     * used when the query has no hint, so it is treated the same as an absent hint.
     */
    @Override
    public boolean hasJoinHint() {
        return hint != null && hint != JoinHint.NONE;
    }

    /**
     * Returns true when at least one hash join conjunct is present.
     *
     * <p>NOTE(review): the other join conjuncts are not considered here although
     * {@link #getOnClauseCondition()} includes them -- confirm this is intended.
     */
    @Override
    public boolean hasJoinCondition() {
        return !CollectionUtils.isEmpty(hashJoinConjuncts);
    }
}

View File

@ -0,0 +1,28 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !join1 --
\N \N \N \N dd 33
\N \N cc 22 cc 23
\N \N ee 42 \N \N
bb 11 bb 12 bb 13
-- !join2 --
bb 12 bb 13
cc 22 cc 23
-- !join3 --
bb 12 bb 13
cc 22 cc 23
ee 42 \N \N
-- !join4 --
\N \N dd 33
bb 12 bb 13
cc 22 cc 23
ee 42 \N \N
-- !join7 --
\N \N bb 2
\N \N cc 2
\N \N ee 2
bb 11 \N \N

View File

@ -0,0 +1,101 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// Regression suite for joins planned by Nereids: USING joins over tables and subqueries
// (qt_join1 .. qt_join4) plus one FULL JOIN with an ON condition (qt_join7).
suite("nereids_test_join3", "query,p0") {
// Enable the Nereids planner and turn off fallback to the original planner.
sql "SET enable_vectorized_engine=true"
sql "SET enable_nereids_planner=true"
sql "SET enable_fallback_to_original_planner=false"
// Start from a freshly created database so earlier runs cannot affect the results.
def DBname = "regression_test_join3"
sql "DROP DATABASE IF EXISTS ${DBname}"
sql "CREATE DATABASE IF NOT EXISTS ${DBname}"
sql "use ${DBname}"
def tbName1 = "t1"
def tbName2 = "t2"
def tbName3 = "t3"
// Three tables with the same schema (name, n); "name" is the column used by the USING joins.
sql """CREATE TABLE IF NOT EXISTS ${tbName1} (name varchar(255), n INTEGER) DISTRIBUTED BY HASH(name) properties("replication_num" = "1");"""
sql """CREATE TABLE IF NOT EXISTS ${tbName2} (name varchar(255), n INTEGER) DISTRIBUTED BY HASH(name) properties("replication_num" = "1");"""
sql """CREATE TABLE IF NOT EXISTS ${tbName3} (name varchar(255), n INTEGER) DISTRIBUTED BY HASH(name) properties("replication_num" = "1");"""
// Data: 'bb' is in all three tables, 'cc' in t2 and t3, 'ee' only in t2, 'dd' only in t3.
sql "INSERT INTO ${tbName1} VALUES ( 'bb', 11 );"
sql "INSERT INTO ${tbName2} VALUES ( 'bb', 12 );"
sql "INSERT INTO ${tbName2} VALUES ( 'cc', 22 );"
sql "INSERT INTO ${tbName2} VALUES ( 'ee', 42 );"
sql "INSERT INTO ${tbName3} VALUES ( 'bb', 13 );"
sql "INSERT INTO ${tbName3} VALUES ( 'cc', 23 );"
sql "INSERT INTO ${tbName3} VALUES ( 'dd', 33 );"
// Chained FULL JOIN ... USING (name) across all three tables.
qt_join1 """
SELECT * FROM ${tbName1} FULL JOIN ${tbName2} USING (name) FULL JOIN ${tbName3} USING (name) ORDER BY 1,2,3,4,5,6;
"""
// INNER JOIN ... USING between two aliased subqueries.
qt_join2 """
SELECT * FROM
(SELECT * FROM ${tbName2}) as s2
INNER JOIN
(SELECT * FROM ${tbName3}) s3
USING (name)
ORDER BY 1,2,3,4;
"""
// LEFT JOIN ... USING between two aliased subqueries.
qt_join3 """
SELECT * FROM
(SELECT * FROM ${tbName2}) as s2
LEFT JOIN
(SELECT * FROM ${tbName3}) s3
USING (name)
ORDER BY 1,2,3,4;
"""
// FULL JOIN ... USING between two aliased subqueries.
qt_join4 """
SELECT * FROM
(SELECT * FROM ${tbName2}) as s2
FULL JOIN
(SELECT * FROM ${tbName3}) s3
USING (name)
ORDER BY 1,2,3,4;
"""
// The NATURAL JOIN cases below (join5, join6) are disabled until a fix lands.
// wait fix
// qt_join5 """
// SELECT * FROM
// (SELECT name, n as s2_n, 2 as s2_2 FROM ${tbName2}) as s2
// NATURAL INNER JOIN
// (SELECT name, n as s3_n, 3 as s3_2 FROM ${tbName3}) s3
// ORDER BY 1,2,3,4;
// """
// qt_join6 """
// SELECT * FROM
// (SELECT name, n as s1_n, 1 as s1_1 FROM ${tbName1}) as s1
// NATURAL INNER JOIN
// (SELECT name, n as s2_n, 2 as s2_2 FROM ${tbName2}) as s2
// NATURAL INNER JOIN
// (SELECT name, n as s3_n, 3 as s3_2 FROM ${tbName3}) s3;
// """
// FULL JOIN with an ON condition instead of USING. s2_n is the constant 2 and the only
// s1_n value inserted above is 11, so no row pairs up.
qt_join7 """
SELECT * FROM
(SELECT name, n as s1_n FROM ${tbName1}) as s1
FULL JOIN
(SELECT name, 2 as s2_n FROM ${tbName2}) as s2
ON (s1_n = s2_n)
ORDER BY 1,2,3,4;
"""
// Cleanup is intentionally left disabled here; the database is dropped at the start of the suite.
// sql "DROP DATABASE IF EXISTS ${DBname}"
}

View File

@ -66,11 +66,11 @@ suite("nereids_using_join") {
sql """INSERT INTO t2 VALUES('6', 3, 1)"""
sql """INSERT INTO t2 VALUES('7', 4, 1)"""
qt_sql """
order_qt_sql """
SELECT t1.col1 FROM t1 JOIN t2 USING (col1)
"""
qt_sql """
order_qt_sql """
SELECT t1.col1 FROM t1 JOIN t2 USING (col1, col2)
"""