From 95e6553d90e52c28bb9ee2a5e478d42b3df63b39 Mon Sep 17 00:00:00 2001 From: Kikyou1997 <33112463+Kikyou1997@users.noreply.github.com> Date: Wed, 28 Dec 2022 19:22:20 +0800 Subject: [PATCH] [feature-wip](nereids) Implement using join (#15311) --- .../apache/doris/nereids/CascadesContext.java | 16 -- .../nereids/parser/LogicalPlanBuilder.java | 59 +++--- .../rules/analysis/BindSlotReference.java | 78 +++++--- .../doris/nereids/trees/plans/PlanType.java | 1 + .../trees/plans/logical/UsingJoin.java | 169 ++++++++++++++++++ .../data/nereids_syntax_p0/test_join3.out | 28 +++ .../nereids_syntax_p0/test_join3.groovy | 101 +++++++++++ .../nereids_syntax_p0/using_join.groovy | 4 +- 8 files changed, 382 insertions(+), 74 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/UsingJoin.java create mode 100644 regression-test/data/nereids_syntax_p0/test_join3.out create mode 100644 regression-test/suites/nereids_syntax_p0/test_join3.groovy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java index 26243a9b55..1a9af3126e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java @@ -215,22 +215,6 @@ public class CascadesContext { return this; } - public void addToTable(Table table) { - tables.add(table); - } - - public void lockTableOnRead() { - for (Table t : tables) { - t.readLock(); - } - } - - public void releaseTableReadLock() { - for (Table t : tables) { - t.readUnlock(); - } - } - /** * Extract tables. */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java index ee32a2a9ae..6ada02d636 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java @@ -199,6 +199,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalSort; import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias; import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; import org.apache.doris.nereids.trees.plans.logical.RelationUtil; +import org.apache.doris.nereids.trees.plans.logical.UsingJoin; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.types.TinyIntType; @@ -1249,28 +1250,6 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { } else { joinType = JoinType.CROSS_JOIN; } - - // TODO: natural join, lateral join, using join, union join - JoinCriteriaContext joinCriteria = join.joinCriteria(); - Optional condition = Optional.empty(); - if (joinCriteria != null) { - if (joinCriteria.booleanExpression() != null) { - condition = Optional.ofNullable(getExpression(joinCriteria.booleanExpression())); - } - if (joinCriteria.USING() != null) { - List ids = - visitIdentifierList(joinCriteria.identifierList()) - .stream().map(UnboundSlot::quoted).collect( - Collectors.toList()); - return new LogicalJoin(JoinType.USING_JOIN, ids, last, plan(join.relationPrimary())); - } - } else { - // keep same with original planner, allow cross/inner join - if (!joinType.isInnerOrCrossJoin()) { - throw new ParseException("on mustn't be empty except for cross/inner join", join); - } - } - JoinHint joinHint = Optional.ofNullable(join.joinHint()).map(hintCtx -> { String hint = typedVisit(join.joinHint()); if (JoinHint.JoinHintType.SHUFFLE.toString().equalsIgnoreCase(hint)) { @@ -1281,13 +1260,35 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { throw new ParseException("Invalid join hint: " + hint, hintCtx); } }).orElse(JoinHint.NONE); - - last = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, - condition.map(ExpressionUtils::extractConjunction) - .orElse(ExpressionUtils.EMPTY_CONDITION), - joinHint, - last, - plan(join.relationPrimary())); + // TODO: natural join, lateral join, using join, union join + JoinCriteriaContext joinCriteria = join.joinCriteria(); + Optional condition = Optional.empty(); + List ids = null; + if (joinCriteria != null) { + if (joinCriteria.booleanExpression() != null) { + condition = Optional.ofNullable(getExpression(joinCriteria.booleanExpression())); + } else if (joinCriteria.USING() != null) { + ids = visitIdentifierList(joinCriteria.identifierList()) + .stream().map(UnboundSlot::quoted).collect( + Collectors.toList()); + } + } else { + // keep same with original planner, allow cross/inner join + if (!joinType.isInnerOrCrossJoin()) { + throw new ParseException("on mustn't be empty except for cross/inner join", join); + } + } + if (ids == null) { + last = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, + condition.map(ExpressionUtils::extractConjunction) + .orElse(ExpressionUtils.EMPTY_CONDITION), + joinHint, + last, + plan(join.relationPrimary())); + } else { + last = new UsingJoin(joinType, last, + plan(join.relationPrimary()), Collections.emptyList(), ids, joinHint); + } } return last; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java index d438db12ea..003967757b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java @@ -66,6 +66,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation; import org.apache.doris.nereids.trees.plans.logical.LogicalSort; +import org.apache.doris.nereids.trees.plans.logical.UsingJoin; import org.apache.doris.planner.PlannerContext; import com.google.common.base.Preconditions; @@ -76,6 +77,7 @@ import org.apache.commons.lang.StringUtils; import java.util.ArrayList; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; @@ -127,41 +129,63 @@ public class BindSlotReference implements AnalysisRuleFactory { return new LogicalFilter<>(boundConjuncts, filter.child()); }) ), - RuleType.BINDING_JOIN_SLOT.build( - logicalJoin().when(Plan::canBind) - .whenNot(j -> j.getJoinType().equals(JoinType.USING_JOIN)).thenApply(ctx -> { - LogicalJoin join = ctx.root; - List cond = join.getOtherJoinConjuncts().stream() - .map(expr -> bind(expr, join.children(), join, ctx.cascadesContext)) - .collect(Collectors.toList()); - List hashJoinConjuncts = join.getHashJoinConjuncts().stream() - .map(expr -> bind(expr, join.children(), join, ctx.cascadesContext)) - .collect(Collectors.toList()); - return new LogicalJoin<>(join.getJoinType(), - hashJoinConjuncts, cond, join.getHint(), join.left(), join.right()); - }) - ), + RuleType.BINDING_USING_JOIN_SLOT.build( - logicalJoin().when(j -> j.getJoinType().equals(JoinType.USING_JOIN)).thenApply(ctx -> { - LogicalJoin join = ctx.root; - List unboundSlots = join.getHashJoinConjuncts(); - List leftSlots = unboundSlots.stream() - .map(expr -> bind(expr, Collections.singletonList(join.left()), - join, ctx.cascadesContext)) - .collect(Collectors.toList()); - List rightSlots = unboundSlots.stream() - .map(expr -> bind(expr, Collections.singletonList(join.right()), - join, ctx.cascadesContext)) - .collect(Collectors.toList()); + usingJoin().thenApply(ctx -> { + UsingJoin using = ctx.root; + LogicalJoin lj = new LogicalJoin(using.getJoinType() == JoinType.CROSS_JOIN + ? JoinType.INNER_JOIN : using.getJoinType(), + using.getHashJoinConjuncts(), + using.getOtherJoinConjuncts(), using.getHint(), using.left(), + using.right()); + List unboundSlots = lj.getHashJoinConjuncts(); + Set slotNames = new HashSet<>(); + List leftOutput = new ArrayList<>(lj.left().getOutput()); + // Suppose A JOIN B USING(name) JOIN C USING(name), [A JOIN B] is the left node, in this case, + // C should combine with table B on C.name=B.name. so we reverse the output to make sure that + // the most right slot is matched with priority. + Collections.reverse(leftOutput); + List leftSlots = new ArrayList<>(); + Scope scope = toScope(leftOutput.stream() + .filter(s -> !slotNames.contains(s.getName())) + .peek(s -> slotNames.add(s.getName())).collect( + Collectors.toList())); + for (Expression unboundSlot : unboundSlots) { + Expression expression = new SlotBinder(scope, lj, ctx.cascadesContext).bind(unboundSlot); + leftSlots.add(expression); + } + slotNames.clear(); + scope = toScope(lj.right().getOutput().stream() + .filter(s -> !slotNames.contains(s.getName())) + .peek(s -> slotNames.add(s.getName())).collect( + Collectors.toList())); + List rightSlots = new ArrayList<>(); + for (Expression unboundSlot : unboundSlots) { + Expression expression = new SlotBinder(scope, lj, ctx.cascadesContext).bind(unboundSlot); + rightSlots.add(expression); + } int size = leftSlots.size(); List hashEqExpr = new ArrayList<>(); for (int i = 0; i < size; i++) { hashEqExpr.add(new EqualTo(leftSlots.get(i), rightSlots.get(i))); } - return new LogicalJoin(JoinType.INNER_JOIN, hashEqExpr, - join.getOtherJoinConjuncts(), join.getHint(), join.left(), join.right()); + return lj.withHashJoinConjuncts(hashEqExpr); }) ), + RuleType.BINDING_JOIN_SLOT.build( + logicalJoin().when(Plan::canBind) + .whenNot(j -> j.getJoinType().equals(JoinType.USING_JOIN)).thenApply(ctx -> { + LogicalJoin join = ctx.root; + List cond = join.getOtherJoinConjuncts().stream() + .map(expr -> bind(expr, join.children(), join, ctx.cascadesContext)) + .collect(Collectors.toList()); + List hashJoinConjuncts = join.getHashJoinConjuncts().stream() + .map(expr -> bind(expr, join.children(), join, ctx.cascadesContext)) + .collect(Collectors.toList()); + return new LogicalJoin<>(join.getJoinType(), + hashJoinConjuncts, cond, join.getHint(), join.left(), join.right()); + }) + ), RuleType.BINDING_AGGREGATE_SLOT.build( logicalAggregate().when(Plan::canBind).thenApply(ctx -> { LogicalAggregate agg = ctx.root; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PlanType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PlanType.java index c1681a08bb..75933232db 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PlanType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PlanType.java @@ -52,6 +52,7 @@ public enum PlanType { LOGICAL_UNION, LOGICAL_EXCEPT, LOGICAL_INTERSECT, + LOGICAL_USING_JOIN, GROUP_PLAN, // physical plan diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/UsingJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/UsingJoin.java new file mode 100644 index 0000000000..8a9d7af752 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/UsingJoin.java @@ -0,0 +1,169 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.plans.logical; + +import org.apache.doris.nereids.memo.GroupExpression; +import org.apache.doris.nereids.properties.LogicalProperties; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.JoinHint; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.PlanType; +import org.apache.doris.nereids.trees.plans.algebra.Join; +import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; +import org.apache.commons.collections.CollectionUtils; + +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * select col1 from t1 join t2 using(col1); + */ +public class UsingJoin + extends LogicalBinary implements Join { + + private final JoinType joinType; + private final ImmutableList otherJoinConjuncts; + private final ImmutableList hashJoinConjuncts; + private final JoinHint hint; + + public UsingJoin(JoinType joinType, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild, + List expressions, List hashJoinConjuncts, + JoinHint hint) { + this(joinType, leftChild, rightChild, expressions, + hashJoinConjuncts, Optional.empty(), Optional.empty(), hint); + } + + public UsingJoin(JoinType joinType, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild, + List expressions, List hashJoinConjuncts, Optional groupExpression, + Optional logicalProperties, + JoinHint hint) { + super(PlanType.LOGICAL_USING_JOIN, groupExpression, logicalProperties, leftChild, rightChild); + this.joinType = joinType; + this.otherJoinConjuncts = ImmutableList.copyOf(expressions); + this.hashJoinConjuncts = ImmutableList.copyOf(hashJoinConjuncts); + this.hint = hint; + } + + @Override + public List computeOutput() { + + List newLeftOutput = left().getOutput().stream().map(o -> o.withNullable(true)) + .collect(Collectors.toList()); + + List newRightOutput = right().getOutput().stream().map(o -> o.withNullable(true)) + .collect(Collectors.toList()); + + switch (joinType) { + case LEFT_SEMI_JOIN: + case LEFT_ANTI_JOIN: + return ImmutableList.copyOf(left().getOutput()); + case RIGHT_SEMI_JOIN: + case RIGHT_ANTI_JOIN: + return ImmutableList.copyOf(right().getOutput()); + case LEFT_OUTER_JOIN: + return ImmutableList.builder() + .addAll(left().getOutput()) + .addAll(newRightOutput) + .build(); + case RIGHT_OUTER_JOIN: + return ImmutableList.builder() + .addAll(newLeftOutput) + .addAll(right().getOutput()) + .build(); + case FULL_OUTER_JOIN: + return ImmutableList.builder() + .addAll(newLeftOutput) + .addAll(newRightOutput) + .build(); + default: + return ImmutableList.builder() + .addAll(left().getOutput()) + .addAll(right().getOutput()) + .build(); + } + } + + @Override + public Plan withGroupExpression(Optional groupExpression) { + return new UsingJoin(joinType, child(0), child(1), otherJoinConjuncts, + hashJoinConjuncts, groupExpression, Optional.of(getLogicalProperties()), hint); + } + + @Override + public Plan withLogicalProperties(Optional logicalProperties) { + return new UsingJoin(joinType, child(0), child(1), otherJoinConjuncts, + hashJoinConjuncts, groupExpression, logicalProperties, hint); + } + + @Override + public Plan withChildren(List children) { + return new UsingJoin(joinType, children.get(0), children.get(1), otherJoinConjuncts, + hashJoinConjuncts, groupExpression, Optional.of(getLogicalProperties()), hint); + } + + @Override + public R accept(PlanVisitor visitor, C context) { + return visitor.visit(this, context); + } + + @Override + public List getExpressions() { + return new Builder() + .addAll(hashJoinConjuncts) + .addAll(otherJoinConjuncts) + .build(); + } + + public JoinType getJoinType() { + return joinType; + } + + public List getOtherJoinConjuncts() { + return otherJoinConjuncts; + } + + public List getHashJoinConjuncts() { + return hashJoinConjuncts; + } + + public JoinHint getHint() { + return hint; + } + + @Override + public Optional getOnClauseCondition() { + return ExpressionUtils.optionalAnd(hashJoinConjuncts, otherJoinConjuncts); + } + + @Override + public boolean hasJoinHint() { + return hint != null; + } + + @Override + public boolean hasJoinCondition() { + return !CollectionUtils.isEmpty(hashJoinConjuncts); + } +} diff --git a/regression-test/data/nereids_syntax_p0/test_join3.out b/regression-test/data/nereids_syntax_p0/test_join3.out new file mode 100644 index 0000000000..3a1541bde9 --- /dev/null +++ b/regression-test/data/nereids_syntax_p0/test_join3.out @@ -0,0 +1,28 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !join1 -- +\N \N \N \N dd 33 +\N \N cc 22 cc 23 +\N \N ee 42 \N \N +bb 11 bb 12 bb 13 + +-- !join2 -- +bb 12 bb 13 +cc 22 cc 23 + +-- !join3 -- +bb 12 bb 13 +cc 22 cc 23 +ee 42 \N \N + +-- !join4 -- +\N \N dd 33 +bb 12 bb 13 +cc 22 cc 23 +ee 42 \N \N + +-- !join7 -- +\N \N bb 2 +\N \N cc 2 +\N \N ee 2 +bb 11 \N \N + diff --git a/regression-test/suites/nereids_syntax_p0/test_join3.groovy b/regression-test/suites/nereids_syntax_p0/test_join3.groovy new file mode 100644 index 0000000000..17a0ad76c7 --- /dev/null +++ b/regression-test/suites/nereids_syntax_p0/test_join3.groovy @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("nereids_test_join3", "query,p0") { + + sql "SET enable_vectorized_engine=true" + sql "SET enable_nereids_planner=true" + sql "SET enable_fallback_to_original_planner=false" + + def DBname = "regression_test_join3" + sql "DROP DATABASE IF EXISTS ${DBname}" + sql "CREATE DATABASE IF NOT EXISTS ${DBname}" + sql "use ${DBname}" + + def tbName1 = "t1" + def tbName2 = "t2" + def tbName3 = "t3" + + sql """CREATE TABLE IF NOT EXISTS ${tbName1} (name varchar(255), n INTEGER) DISTRIBUTED BY HASH(name) properties("replication_num" = "1");""" + sql """CREATE TABLE IF NOT EXISTS ${tbName2} (name varchar(255), n INTEGER) DISTRIBUTED BY HASH(name) properties("replication_num" = "1");""" + sql """CREATE TABLE IF NOT EXISTS ${tbName3} (name varchar(255), n INTEGER) DISTRIBUTED BY HASH(name) properties("replication_num" = "1");""" + + sql "INSERT INTO ${tbName1} VALUES ( 'bb', 11 );" + sql "INSERT INTO ${tbName2} VALUES ( 'bb', 12 );" + sql "INSERT INTO ${tbName2} VALUES ( 'cc', 22 );" + sql "INSERT INTO ${tbName2} VALUES ( 'ee', 42 );" + sql "INSERT INTO ${tbName3} VALUES ( 'bb', 13 );" + sql "INSERT INTO ${tbName3} VALUES ( 'cc', 23 );" + sql "INSERT INTO ${tbName3} VALUES ( 'dd', 33 );" + + qt_join1 """ + SELECT * FROM ${tbName1} FULL JOIN ${tbName2} USING (name) FULL JOIN ${tbName3} USING (name) ORDER BY 1,2,3,4,5,6; + """ + qt_join2 """ + SELECT * FROM + (SELECT * FROM ${tbName2}) as s2 + INNER JOIN + (SELECT * FROM ${tbName3}) s3 + USING (name) + ORDER BY 1,2,3,4; + """ + qt_join3 """ + SELECT * FROM + (SELECT * FROM ${tbName2}) as s2 + LEFT JOIN + (SELECT * FROM ${tbName3}) s3 + USING (name) + ORDER BY 1,2,3,4; + """ + qt_join4 """ + SELECT * FROM + (SELECT * FROM ${tbName2}) as s2 + FULL JOIN + (SELECT * FROM ${tbName3}) s3 + USING (name) + ORDER BY 1,2,3,4; + """ + +// wait fix +// qt_join5 """ +// SELECT * FROM +// (SELECT name, n as s2_n, 2 as s2_2 FROM ${tbName2}) as s2 +// NATURAL INNER JOIN +// (SELECT name, n as s3_n, 3 as s3_2 FROM ${tbName3}) s3 +// ORDER BY 1,2,3,4; +// """ + +// qt_join6 """ +// SELECT * FROM +// (SELECT name, n as s1_n, 1 as s1_1 FROM ${tbName1}) as s1 +// NATURAL INNER JOIN +// (SELECT name, n as s2_n, 2 as s2_2 FROM ${tbName2}) as s2 +// NATURAL INNER JOIN +// (SELECT name, n as s3_n, 3 as s3_2 FROM ${tbName3}) s3; +// """ + + qt_join7 """ + SELECT * FROM + (SELECT name, n as s1_n FROM ${tbName1}) as s1 + FULL JOIN + (SELECT name, 2 as s2_n FROM ${tbName2}) as s2 + ON (s1_n = s2_n) + ORDER BY 1,2,3,4; + """ + + // sql "DROP DATABASE IF EXISTS ${DBname}" +} diff --git a/regression-test/suites/nereids_syntax_p0/using_join.groovy b/regression-test/suites/nereids_syntax_p0/using_join.groovy index ecbc714d9a..d53d378650 100644 --- a/regression-test/suites/nereids_syntax_p0/using_join.groovy +++ b/regression-test/suites/nereids_syntax_p0/using_join.groovy @@ -66,11 +66,11 @@ suite("nereids_using_join") { sql """INSERT INTO t2 VALUES('6', 3, 1)""" sql """INSERT INTO t2 VALUES('7', 4, 1)""" - qt_sql """ + order_qt_sql """ SELECT t1.col1 FROM t1 JOIN t2 USING (col1) """ - qt_sql """ + order_qt_sql """ SELECT t1.col1 FROM t1 JOIN t2 USING (col1, col2) """