[Feature](Nereids) support MarkJoin (#16616)

# Proposed changes
1. The new optimizer supports combining subqueries with disjunctions. This is implemented with MarkJoin, so it behaves the same as the old optimizer. For design details see: https://emmymiao87.github.io/jekyll/update/2021/07/25/Mark-Join.html.
2. Implicit type conversion is performed when conjuncts are generated after subquery parsing.
3. Convert the unnesting of a scalarSubquery in a filter from filter + join to join + conjuncts.
This commit is contained in:
zhengshiJ
2023-03-08 14:26:24 +08:00
committed by GitHub
parent 626fbc34f9
commit aab14922af
78 changed files with 1531 additions and 192 deletions

View File

@ -504,6 +504,10 @@ public class StmtRewriter {
+ "expression: "
+ exprWithSubquery.toSql());
}
if (exprWithSubquery instanceof BinaryPredicate && (childrenContainInOrExists(exprWithSubquery))) {
throw new AnalysisException("Not support binaryOperator children at least one is in or exists subquery"
+ exprWithSubquery.toSql());
}
if (exprWithSubquery instanceof ExistsPredicate) {
// Check if we can determine the result of an ExistsPredicate during analysis.
@ -542,6 +546,16 @@ public class StmtRewriter {
}
}
/**
 * Reports whether any direct child of {@code expr} is an {@code InPredicate} or an
 * {@code ExistsPredicate}. Only the immediate children are inspected; the scan stops at
 * the first match.
 *
 * @param expr the expression whose children are examined
 * @return {@code true} if at least one direct child is an IN or EXISTS predicate
 */
private static boolean childrenContainInOrExists(Expr expr) {
    for (Expr operand : expr.getChildren()) {
        if (operand instanceof InPredicate || operand instanceof ExistsPredicate) {
            return true;
        }
    }
    return false;
}
/**
* Replace an ExistsPredicate that contains a subquery with a BoolLiteral if we

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids;
import org.apache.doris.analysis.StatementBase;
import org.apache.doris.common.IdGenerator;
import org.apache.doris.nereids.rules.analysis.ColumnAliasGenerator;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.qe.ConnectContext;
@ -28,7 +29,9 @@ import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.Maps;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import javax.annotation.concurrent.GuardedBy;
/**
@ -51,6 +54,10 @@ public class StatementContext {
private StatementBase parsedStatement;
private Set<String> columnNames;
private ColumnAliasGenerator columnAliasGenerator;
/**
 * Creates a statement context bound to whatever {@code ConnectContext.get()} returns
 * at construction time.
 */
public StatementContext() {
this.connectContext = ConnectContext.get();
}
@ -111,4 +118,22 @@ public class StatementContext {
}
return supplier.get();
}
/**
 * Returns the column-name set recorded on this statement.
 *
 * <p>When no set has been installed, a brand-new empty {@link HashSet} is returned on every
 * call and is NOT stored in the field.
 * NOTE(review): while the field is unset, anything a caller adds to the returned set is
 * therefore discarded -- confirm whether that is intended before relying on it.
 *
 * @return the installed set, or a fresh empty set if none has been installed
 */
public Set<String> getColumnNames() {
    if (columnNames != null) {
        return columnNames;
    }
    return new HashSet<>();
}
/**
 * Installs the set used to track column names for this statement. The given reference is
 * stored as-is (no copy is made).
 *
 * @param columnNames the set to store; replaces any previously stored set
 */
public void setColumnNames(Set<String> columnNames) {
this.columnNames = columnNames;
}
/**
 * Returns the column alias generator of this statement, creating it on first use.
 * No synchronization is performed around the lazy creation.
 *
 * @return the generator instance cached on this context
 */
public ColumnAliasGenerator getColumnAliasGenerator() {
    if (columnAliasGenerator == null) {
        columnAliasGenerator = new ColumnAliasGenerator(this);
    }
    return columnAliasGenerator;
}
/**
 * Produces the next alias from this statement's column alias generator
 * (the generator is created on demand by {@code getColumnAliasGenerator()}).
 *
 * @return the value of {@code getNextAlias()} on that generator
 */
public String generateColumnName() {
return getColumnAliasGenerator().getNextAlias();
}
}

View File

@ -57,6 +57,7 @@ import org.apache.doris.nereids.trees.expressions.InSubquery;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.Or;
@ -193,6 +194,12 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
return context.findSlotRef(slotReference.getExprId());
}
/**
 * Translates a mark-join slot reference.
 *
 * <p>If the reference reports {@code isExistsHasAgg()}, it is translated to the constant
 * {@code true}; otherwise the slot ref registered in the context for its expr id is returned.
 */
@Override
public Expr visitMarkJoinReference(MarkJoinSlotReference markJoinSlotReference, PlanTranslatorContext context) {
    if (markJoinSlotReference.isExistsHasAgg()) {
        return new BoolLiteral(true);
    }
    return context.findSlotRef(markJoinSlotReference.getExprId());
}
@Override
public Expr visitLiteral(Literal literal, PlanTranslatorContext context) {
return literal.toLegacyLiteral();

View File

@ -60,6 +60,7 @@ import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
@ -956,7 +957,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
HashJoinNode hashJoinNode = new HashJoinNode(context.nextPlanNodeId(), leftPlanRoot,
rightPlanRoot, JoinType.toJoinOperator(joinType), execEqConjuncts, Lists.newArrayList(),
null, null, null);
null, null, null, hashJoin.isMarkJoin());
PlanFragment currentFragment;
if (JoinUtils.shouldColocateJoin(physicalHashJoin)) {
@ -1012,13 +1013,15 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
.forEach(s -> hashOutputSlotReferenceMap.put(s.getExprId(), s));
Map<ExprId, SlotReference> leftChildOutputMap = Maps.newHashMap();
hashJoin.child(0).getOutput().stream()
Stream.concat(hashJoin.child(0).getOutput().stream(), hashJoin.child(0).getNonUserVisibleOutput().stream())
.map(SlotReference.class::cast)
.forEach(s -> leftChildOutputMap.put(s.getExprId(), s));
context.getOutputMarkJoinSlot().stream().forEach(s -> leftChildOutputMap.put(s.getExprId(), s));
Map<ExprId, SlotReference> rightChildOutputMap = Maps.newHashMap();
hashJoin.child(1).getOutput().stream()
Stream.concat(hashJoin.child(1).getOutput().stream(), hashJoin.child(1).getNonUserVisibleOutput().stream())
.map(SlotReference.class::cast)
.forEach(s -> rightChildOutputMap.put(s.getExprId(), s));
context.getOutputMarkJoinSlot().stream().forEach(s -> rightChildOutputMap.put(s.getExprId(), s));
// translate runtime filter
context.getRuntimeTranslator().ifPresent(runtimeFilterTranslator -> runtimeFilterTranslator
.getRuntimeFilterOfHashJoinNode(physicalHashJoin)
@ -1040,6 +1043,9 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
SlotReference sf = leftChildOutputMap.get(context.findExprId(leftSlotDescriptor.getId()));
SlotDescriptor sd = context.createSlotDesc(intermediateDescriptor, sf);
leftIntermediateSlotDescriptor.add(sd);
if (sf instanceof MarkJoinSlotReference && hashJoin.getFilterConjuncts().isEmpty()) {
outputSlotReferences.add(sf);
}
}
} else if (hashJoin.getOtherJoinConjuncts().isEmpty()
&& (joinType == JoinType.RIGHT_ANTI_JOIN || joinType == JoinType.RIGHT_SEMI_JOIN)) {
@ -1076,6 +1082,14 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
}
}
if (hashJoin.getMarkJoinSlotReference().isPresent()) {
if (hashJoin.getFilterConjuncts().isEmpty()) {
outputSlotReferences.add(hashJoin.getMarkJoinSlotReference().get());
context.setOutputMarkJoinSlot(hashJoin.getMarkJoinSlotReference().get());
}
context.createSlotDesc(intermediateDescriptor, hashJoin.getMarkJoinSlotReference().get());
}
// set slots as nullable for outer join
if (joinType == JoinType.LEFT_OUTER_JOIN || joinType == JoinType.FULL_OUTER_JOIN) {
rightIntermediateSlotDescriptor.forEach(sd -> sd.setIsNullable(true));
@ -1142,7 +1156,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
NestedLoopJoinNode nestedLoopJoinNode = new NestedLoopJoinNode(context.nextPlanNodeId(),
leftFragmentPlanRoot, rightFragmentPlanRoot, tupleIds, JoinType.toJoinOperator(joinType),
null, null, null);
null, null, null, nestedLoopJoin.isMarkJoin());
if (nestedLoopJoin.getStats() != null) {
nestedLoopJoinNode.setCardinality((long) nestedLoopJoin.getStats().getRowCount());
}
@ -1157,13 +1171,17 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
}
Map<ExprId, SlotReference> leftChildOutputMap = Maps.newHashMap();
nestedLoopJoin.child(0).getOutput().stream()
Stream.concat(nestedLoopJoin.child(0).getOutput().stream(),
nestedLoopJoin.child(0).getNonUserVisibleOutput().stream())
.map(SlotReference.class::cast)
.forEach(s -> leftChildOutputMap.put(s.getExprId(), s));
context.getOutputMarkJoinSlot().stream().forEach(s -> leftChildOutputMap.put(s.getExprId(), s));
Map<ExprId, SlotReference> rightChildOutputMap = Maps.newHashMap();
nestedLoopJoin.child(1).getOutput().stream()
Stream.concat(nestedLoopJoin.child(1).getOutput().stream(),
nestedLoopJoin.child(1).getNonUserVisibleOutput().stream())
.map(SlotReference.class::cast)
.forEach(s -> rightChildOutputMap.put(s.getExprId(), s));
context.getOutputMarkJoinSlot().stream().forEach(s -> rightChildOutputMap.put(s.getExprId(), s));
// make intermediate tuple
List<SlotDescriptor> leftIntermediateSlotDescriptor = Lists.newArrayList();
List<SlotDescriptor> rightIntermediateSlotDescriptor = Lists.newArrayList();
@ -1198,6 +1216,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
.map(outputSlotReferenceMap::get)
.filter(Objects::nonNull)
.collect(Collectors.toList());
// TODO: because of the limitation of be, the VNestedLoopJoinNode will output column from both children
// in the intermediate tuple, so fe have to do the same, if be fix the problem, we can change it back.
for (SlotDescriptor leftSlotDescriptor : leftSlotDescriptors) {
@ -1207,6 +1226,9 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
SlotReference sf = leftChildOutputMap.get(context.findExprId(leftSlotDescriptor.getId()));
SlotDescriptor sd = context.createSlotDesc(intermediateDescriptor, sf);
leftIntermediateSlotDescriptor.add(sd);
if (sf instanceof MarkJoinSlotReference && nestedLoopJoin.getFilterConjuncts().isEmpty()) {
outputSlotReferences.add(sf);
}
}
for (SlotDescriptor rightSlotDescriptor : rightSlotDescriptors) {
if (!rightSlotDescriptor.isMaterialized()) {
@ -1215,6 +1237,17 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
SlotReference sf = rightChildOutputMap.get(context.findExprId(rightSlotDescriptor.getId()));
SlotDescriptor sd = context.createSlotDesc(intermediateDescriptor, sf);
rightIntermediateSlotDescriptor.add(sd);
if (sf instanceof MarkJoinSlotReference && nestedLoopJoin.getFilterConjuncts().isEmpty()) {
outputSlotReferences.add(sf);
}
}
if (nestedLoopJoin.getMarkJoinSlotReference().isPresent()) {
if (nestedLoopJoin.getFilterConjuncts().isEmpty()) {
outputSlotReferences.add(nestedLoopJoin.getMarkJoinSlotReference().get());
context.setOutputMarkJoinSlot(nestedLoopJoin.getMarkJoinSlotReference().get());
}
context.createSlotDesc(intermediateDescriptor, nestedLoopJoin.getMarkJoinSlotReference().get());
}
// set slots as nullable for outer join

View File

@ -29,6 +29,7 @@ import org.apache.doris.catalog.TableIf;
import org.apache.doris.common.IdGenerator;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
@ -81,6 +82,8 @@ public class PlanTranslatorContext {
private final Map<ExprId, SlotRef> bufferedSlotRefForWindow = Maps.newHashMap();
private TupleDescriptor bufferedTupleForWindow = null;
private List<MarkJoinSlotReference> outputMarkJoinSlot = Lists.newArrayList();
/**
 * Creates a translator context whose runtime filter translator is built from the
 * runtime filter context of the given cascades context.
 *
 * @param ctx the cascades context supplying the runtime filter context
 */
public PlanTranslatorContext(CascadesContext ctx) {
this.translator = new RuntimeFilterTranslator(ctx.getRuntimeFilterContext());
}
@ -210,4 +213,12 @@ public class PlanTranslatorContext {
/** Returns the descriptor table held by this context. */
public DescriptorTable getDescTable() {
return descTable;
}
/**
 * Records a mark-join slot as an output slot.
 *
 * <p>Despite the "set" prefix, this appends to the list; previously recorded slots are kept.
 *
 * @param markJoinSlotReference the mark-join slot to append
 */
public void setOutputMarkJoinSlot(MarkJoinSlotReference markJoinSlotReference) {
outputMarkJoinSlot.add(markJoinSlotReference);
}
/**
 * Returns the mark-join slots recorded so far. The internal list itself is returned,
 * not a copy.
 */
public List<MarkJoinSlotReference> getOutputMarkJoinSlot() {
return outputMarkJoinSlot;
}
}

View File

@ -84,7 +84,6 @@ public class NereidsRewriter extends BatchRewriteJob {
new AvgDistinctToSumDivCount(),
new CountDistinctRewrite(),
new NormalizeAggregate(),
new ExtractFilterFromCrossJoin()
),
@ -116,6 +115,14 @@ public class NereidsRewriter extends BatchRewriteJob {
)
),
// The rule modification needs to be done after the subquery is unnested,
// because for scalarSubQuery, the connection condition is stored in apply in the analyzer phase,
// but when normalizeAggregate is performed, the members in apply cannot be obtained,
// resulting in inconsistent output results and results in apply
topDown(
new NormalizeAggregate()
),
topDown(
new AdjustAggregateNullableForEmptySet()
),

View File

@ -432,6 +432,7 @@ public class GraphSimplifier {
join.getJoinType(),
join.getHashJoinConjuncts(),
join.getOtherJoinConjuncts(),
join.getMarkJoinSlotReference(),
join.getLogicalProperties(),
join.left(),
join.right());
@ -442,6 +443,7 @@ public class GraphSimplifier {
join.getHashJoinConjuncts(),
join.getOtherJoinConjuncts(),
join.getHint(),
join.getMarkJoinSlotReference(),
join.getLogicalProperties(),
join.left(),
join.right());

View File

@ -57,6 +57,7 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@ -226,16 +227,19 @@ public class PlanReceiver implements AbstractReceiver {
() -> JoinUtils.getJoinOutput(joinType, left, right));
if (JoinUtils.shouldNestedLoopJoin(joinType, hashConjuncts)) {
return Lists.newArrayList(
new PhysicalNestedLoopJoin<>(joinType, hashConjuncts, otherConjuncts, joinProperties, left,
right),
new PhysicalNestedLoopJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, joinProperties,
new PhysicalNestedLoopJoin<>(joinType, hashConjuncts, otherConjuncts,
Optional.empty(), joinProperties,
left, right),
new PhysicalNestedLoopJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, Optional.empty(),
joinProperties,
right, left));
} else {
return Lists.newArrayList(
new PhysicalHashJoin<>(joinType, hashConjuncts, otherConjuncts, JoinHint.NONE, joinProperties,
left,
right),
new PhysicalHashJoin<>(joinType, hashConjuncts, otherConjuncts, JoinHint.NONE, Optional.empty(),
joinProperties,
left, right),
new PhysicalHashJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, JoinHint.NONE,
Optional.empty(),
joinProperties,
right, left));
}
@ -258,6 +262,17 @@ public class PlanReceiver implements AbstractReceiver {
return joinType;
}
/**
 * Reports whether any of the given edges belongs to a mark join.
 *
 * <p>While scanning, every edge's join type is checked against the join type of the edge
 * before it; a mismatch fails the {@code Preconditions} argument check.
 *
 * @param edges the edges to inspect
 * @return {@code true} if at least one edge's join is a mark join
 */
private boolean extractIsMarkJoin(List<Edge> edges) {
    JoinType previousType = null;
    boolean anyMarkJoin = false;
    for (Edge current : edges) {
        Preconditions.checkArgument(previousType == null || previousType == current.getJoinType());
        if (current.getJoin().isMarkJoin()) {
            anyMarkJoin = true;
        }
        previousType = current.getJoinType();
    }
    return anyMarkJoin;
}
@Override
public void addGroup(long bitmap, Group group) {
Preconditions.checkArgument(LongBitmap.getCardinality(bitmap) == 1);
@ -322,8 +337,8 @@ public class PlanReceiver implements AbstractReceiver {
} else if (physicalPlan instanceof AbstractPhysicalJoin) {
AbstractPhysicalJoin physicalJoin = (AbstractPhysicalJoin) physicalPlan;
logicalPlan = new LogicalJoin<>(physicalJoin.getJoinType(), physicalJoin.getHashJoinConjuncts(),
physicalJoin.getOtherJoinConjuncts(), JoinHint.NONE, physicalJoin.child(0),
physicalJoin.child(1));
physicalJoin.getOtherJoinConjuncts(), JoinHint.NONE, physicalJoin.getMarkJoinSlotReference(),
physicalJoin.child(0), physicalJoin.child(1));
} else {
throw new RuntimeException("DPhyp can only handle join and project operator");
}

View File

@ -1218,6 +1218,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
ExpressionUtils.EMPTY_CONDITION,
ExpressionUtils.EMPTY_CONDITION,
JoinHint.NONE,
Optional.empty(),
left,
right);
left = withJoinRelations(left, relation);
@ -1481,6 +1482,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
condition.map(ExpressionUtils::extractConjunction)
.orElse(ExpressionUtils.EMPTY_CONDITION),
joinHint,
Optional.empty(),
last,
plan(join.relationPrimary()));
} else {

View File

@ -80,7 +80,7 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
Map<NamedExpression, Pair<RelationId, Slot>> aliasTransferMap = ctx.getAliasTransferMap();
join.right().accept(this, context);
join.left().accept(this, context);
if (deniedJoinType.contains(join.getJoinType())) {
if (deniedJoinType.contains(join.getJoinType()) || join.isMarkJoin()) {
// copy to avoid bug when next call of getOutputSet()
Set<Slot> slots = join.getOutputSet();
slots.forEach(aliasTransferMap::remove);

View File

@ -21,6 +21,7 @@ import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
@ -78,7 +79,8 @@ public class Validator extends PlanPostProcessor {
.<Set<Slot>>map(expr -> expr.collect(Slot.class::isInstance))
.flatMap(Collection::stream).collect(Collectors.toSet());
for (Slot slot : slotsUsedByFilter) {
Preconditions.checkState(childOutputSet.contains(slot));
Preconditions.checkState(childOutputSet.contains(slot)
|| slot instanceof MarkJoinSlotReference);
}
child.accept(this, context);

View File

@ -153,8 +153,8 @@ public class BindExpression implements AnalysisRuleFactory {
LogicalJoin<Plan, Plan> lj = new LogicalJoin<>(using.getJoinType() == JoinType.CROSS_JOIN
? JoinType.INNER_JOIN : using.getJoinType(),
using.getHashJoinConjuncts(),
using.getOtherJoinConjuncts(), using.getHint(), using.left(),
using.right());
using.getOtherJoinConjuncts(), using.getHint(), using.getMarkJoinSlotReference(),
using.left(), using.right());
List<Expression> unboundSlots = lj.getHashJoinConjuncts();
Set<String> slotNames = new HashSet<>();
List<Slot> leftOutput = new ArrayList<>(lj.left().getOutput());
@ -201,7 +201,8 @@ public class BindExpression implements AnalysisRuleFactory {
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(Collectors.toList());
return new LogicalJoin<>(join.getJoinType(),
hashJoinConjuncts, cond, join.getHint(), join.left(), join.right());
hashJoinConjuncts, cond, join.getHint(), join.getMarkJoinSlotReference(),
join.left(), join.right());
})
),
RuleType.BINDING_AGGREGATE_SLOT.build(

View File

@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotNotFromChildren;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.plans.Plan;
@ -71,8 +72,8 @@ public class CheckAfterRewrite extends OneAnalysisRuleFactory {
}
}
private Set<Slot> removeValidSlotsNotFromChildren(Set<Slot> virtualSlots, Set<ExprId> childrenOutput) {
return virtualSlots.stream()
private Set<Slot> removeValidSlotsNotFromChildren(Set<Slot> slots, Set<ExprId> childrenOutput) {
return slots.stream()
.filter(expr -> {
if (expr instanceof VirtualSlotReference) {
List<Expression> realExpressions = ((VirtualSlotReference) expr).getRealExpressions();
@ -85,7 +86,7 @@ public class CheckAfterRewrite extends OneAnalysisRuleFactory {
.flatMap(Set::stream)
.anyMatch(realUsedExpr -> !childrenOutput.contains(realUsedExpr.getExprId()));
} else {
return true;
return !(expr instanceof SlotNotFromChildren);
}
})
.collect(Collectors.toSet());

View File

@ -0,0 +1,39 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/apache/impala/blob/branch-2.9.0/fe/src/main/java/org/apache/impala/ColumnAliasGenerator.java
// and modified by Doris
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.common.AliasGenerator;
import org.apache.doris.nereids.StatementContext;
import com.google.common.base.Preconditions;
/**
 * Generates column aliases needed during the rewrite process.
 *
 * <p>Aliases use the {@code $c$} prefix, and the column names already recorded on the
 * statement context at construction time are registered as used aliases.
 */
public class ColumnAliasGenerator extends AliasGenerator {
    private static final String DEFAULT_COL_ALIAS_PREFIX = "$c$";

    /**
     * @param statementContext the statement context supplying the already used column names;
     *                         must not be null
     */
    public ColumnAliasGenerator(StatementContext statementContext) {
        // checkNotNull hands back the validated reference, so the null check still runs first.
        StatementContext checkedContext = Preconditions.checkNotNull(statementContext);
        this.aliasPrefix = DEFAULT_COL_ALIAS_PREFIX;
        this.usedAliases.addAll(checkedContext.getColumnNames());
    }
}

View File

@ -34,6 +34,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IntegralDivide;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
@ -187,7 +188,8 @@ class FunctionBinder extends DefaultExpressionRewriter<CascadesContext> {
Expression right = compoundPredicate.right().accept(this, context);
Expression ret = compoundPredicate.withChildren(left, right);
ret.children().forEach(e -> {
if (!e.getDataType().isBooleanType() && !e.getDataType().isNullType()) {
if (!e.getDataType().isBooleanType() && !e.getDataType().isNullType()
&& !(e instanceof SubqueryExpr)) {
throw new AnalysisException(String.format(
"Operand '%s' part of predicate " + "'%s' should return type 'BOOLEAN' but "
+ "returns type '%s'.",

View File

@ -30,6 +30,7 @@ import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import com.google.common.base.Preconditions;
import org.apache.commons.lang.StringUtils;
import java.util.List;
@ -69,12 +70,15 @@ class SlotBinder extends SubExprAnalyzer {
/**
 * Binds an unbound alias: the child is rewritten first, then wrapped in an {@code Alias}.
 *
 * <p>The alias name is chosen in this order: the explicit alias if present, the child's
 * own name if it is a {@code NamedExpression}, otherwise the child's SQL text. The chosen
 * name is passed to {@code collectColumnNames} before the alias is built.
 */
public Expression visitUnboundAlias(UnboundAlias unboundAlias, CascadesContext context) {
    Expression boundChild = unboundAlias.child().accept(this, context);
    String aliasName;
    if (unboundAlias.getAlias().isPresent()) {
        aliasName = unboundAlias.getAlias().get();
    } else if (boundChild instanceof NamedExpression) {
        aliasName = ((NamedExpression) boundChild).getName();
    } else {
        // TODO: resolve aliases
        aliasName = boundChild.toSql();
    }
    collectColumnNames(aliasName);
    return new Alias(boundChild, aliasName);
}
@ -98,7 +102,8 @@ class SlotBinder extends SubExprAnalyzer {
// if unbound finally, check will throw exception
return unboundSlot;
case 1:
if (!foundInThisScope) {
if (!foundInThisScope
&& !getScope().getOuterScope().get().getCorrelatedSlots().contains(bounded.get(0))) {
getScope().getOuterScope().get().getCorrelatedSlots().add(bounded.get(0));
}
return bounded.get(0);
@ -218,4 +223,11 @@ class SlotBinder extends SubExprAnalyzer {
+ StringUtils.join(nameParts, "."));
}).collect(Collectors.toList());
}
/**
 * Adds {@code columnName} to the column-name set obtained from the statement context.
 *
 * @param columnName the name to record
 * @throws AnalysisException if the set reports that the name was not added
 */
private void collectColumnNames(String columnName) {
    Preconditions.checkNotNull(getCascadesContext());
    boolean added = getCascadesContext().getStatementContext().getColumnNames().add(columnName);
    if (added) {
        return;
    }
    throw new AnalysisException("Collect column name failed, columnName : " + columnName);
}
}

View File

@ -20,6 +20,8 @@ package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Exists;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InSubquery;
@ -81,6 +83,7 @@ class SubExprAnalyzer extends DefaultExpressionRewriter<CascadesContext> {
AnalyzedResult analyzedResult = analyzeSubquery(expr);
checkOutputColumn(analyzedResult.getLogicalPlan());
checkHasNotAgg(analyzedResult);
checkHasGroupBy(analyzedResult);
checkRootIsLimit(analyzedResult);
@ -101,6 +104,21 @@ class SubExprAnalyzer extends DefaultExpressionRewriter<CascadesContext> {
return new ScalarSubquery(analyzedResult.getLogicalPlan(), analyzedResult.getCorrelatedSlots());
}
/**
 * Rejects comparison predicates that have an IN-subquery or EXISTS anywhere under either
 * operand; every other binary operator is handed to the default {@code visit}.
 *
 * @throws AnalysisException for a comparison predicate over an IN/EXISTS subquery
 */
@Override
public Expression visitBinaryOperator(BinaryOperator binaryOperator, CascadesContext context) {
    boolean comparesInOrExists = childrenAtLeastOneInOrExistsSub(binaryOperator)
            && binaryOperator instanceof ComparisonPredicate;
    if (!comparesInOrExists) {
        return visit(binaryOperator, context);
    }
    throw new AnalysisException("Not support binaryOperator children at least one is in or exists subquery");
}
/**
 * Reports whether the left or the right operand matches (via {@code anyMatch}) an
 * {@code InSubquery} or an {@code Exists} expression.
 */
private boolean childrenAtLeastOneInOrExistsSub(BinaryOperator binaryOperator) {
    if (binaryOperator.left().anyMatch(e -> e instanceof InSubquery || e instanceof Exists)) {
        return true;
    }
    return binaryOperator.right().anyMatch(e -> e instanceof InSubquery || e instanceof Exists);
}
private void checkOutputColumn(LogicalPlan plan) {
if (plan.getOutput().size() != 1) {
throw new AnalysisException("Multiple columns returned by subquery are not yet supported. Found "
@ -129,6 +147,16 @@ class SubExprAnalyzer extends DefaultExpressionRewriter<CascadesContext> {
}
}
/**
 * Fails analysis when the subquery is correlated and also reports {@code hasAgg()}.
 * Uncorrelated subqueries are never rejected here.
 *
 * @throws AnalysisException if the analyzed subquery is correlated and has aggregation
 */
private void checkHasNotAgg(AnalyzedResult analyzedResult) {
    if (analyzedResult.isCorrelated() && analyzedResult.hasAgg()) {
        throw new AnalysisException("Unsupported correlated subquery with grouping and/or aggregation "
                + analyzedResult.getLogicalPlan());
    }
}
private void checkRootIsLimit(AnalyzedResult analyzedResult) {
if (!analyzedResult.isCorrelated()) {
return;

View File

@ -18,30 +18,43 @@
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Exists;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InSubquery;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.ScalarSubquery;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
@ -58,23 +71,55 @@ public class SubqueryToApply implements AnalysisRuleFactory {
RuleType.FILTER_SUBQUERY_TO_APPLY.build(
logicalFilter().thenApply(ctx -> {
LogicalFilter<Plan> filter = ctx.root;
Set<SubqueryExpr> subqueryExprs = filter.getPredicate().collect(SubqueryExpr.class::isInstance);
if (subqueryExprs.isEmpty()) {
ImmutableList<Set> subqueryExprsList = filter.getConjuncts().stream()
.map(e -> (Set) e.collect(SubqueryExpr.class::isInstance))
.collect(ImmutableList.toImmutableList());
if (subqueryExprsList.stream()
.flatMap(Collection::stream).noneMatch(SubqueryExpr.class::isInstance)) {
return filter;
}
// first step: Replace the subquery of predicate in LogicalFilter
// second step: Replace subquery with LogicalApply
return new LogicalFilter<>(new ReplaceSubquery().replace(filter.getConjuncts()),
subqueryToApply(
subqueryExprs, filter.child(), ctx.cascadesContext
));
List<Expression> oldConjuncts = ImmutableList.copyOf(filter.getConjuncts());
ImmutableList.Builder<Expression> newConjuncts = new ImmutableList.Builder<>();
LogicalPlan applyPlan = null;
LogicalPlan tmpPlan = (LogicalPlan) filter.child();
// Subquery traversal with the conjunct of and as the granularity.
for (int i = 0; i < subqueryExprsList.size(); ++i) {
Set<SubqueryExpr> subqueryExprs = subqueryExprsList.get(i);
if (subqueryExprs.isEmpty()) {
newConjuncts.add(oldConjuncts.get(i));
continue;
}
// first step: Replace the subquery of predicate in LogicalFilter
// second step: Replace subquery with LogicalApply
ReplaceSubquery replaceSubquery = new ReplaceSubquery(
ctx.statementContext, false);
SubqueryContext context = new SubqueryContext(subqueryExprs);
Expression conjunct = replaceSubquery.replace(oldConjuncts.get(i), context);
applyPlan = subqueryToApply(subqueryExprs.stream()
.collect(ImmutableList.toImmutableList()), tmpPlan,
context.getSubqueryToMarkJoinSlot(),
context.getSubqueryCorrespondingConjunct(), ctx.cascadesContext,
Optional.of(conjunct), false);
tmpPlan = applyPlan;
if (!(subqueryExprs.size() == 1
&& subqueryExprs.stream().anyMatch(ScalarSubquery.class::isInstance))) {
newConjuncts.add(conjunct);
}
}
Set<Expression> conjects = new HashSet<>();
conjects.addAll(newConjuncts.build());
return new LogicalFilter<>(conjects, applyPlan);
})
),
RuleType.PROJECT_SUBQUERY_TO_APPLY.build(
logicalProject().thenApply(ctx -> {
LogicalProject<Plan> project = ctx.root;
Set<SubqueryExpr> subqueryExprs = new HashSet<>();
Set<SubqueryExpr> subqueryExprs = new LinkedHashSet<>();
project.getProjects().stream()
.filter(Alias.class::isInstance)
.map(Alias.class::cast)
@ -86,40 +131,97 @@ public class SubqueryToApply implements AnalysisRuleFactory {
return project;
}
SubqueryContext context = new SubqueryContext(subqueryExprs);
return new LogicalProject(project.getProjects().stream()
.map(p -> p.withChildren(new ReplaceSubquery().replace(p)))
.map(p -> p.withChildren(
new ReplaceSubquery(ctx.statementContext, true)
.replace(p, context)))
.collect(ImmutableList.toImmutableList()),
subqueryToApply(
subqueryExprs, project.child(), ctx.cascadesContext
subqueryExprs.stream().collect(ImmutableList.toImmutableList()),
(LogicalPlan) project.child(),
context.getSubqueryToMarkJoinSlot(), context.getSubqueryCorrespondingConjunct(),
ctx.cascadesContext,
Optional.empty(), true
));
})
)
);
}
private Plan subqueryToApply(Set<SubqueryExpr> subqueryExprs,
Plan childPlan, CascadesContext ctx) {
Plan tmpPlan = childPlan;
for (SubqueryExpr subqueryExpr : subqueryExprs) {
private LogicalPlan subqueryToApply(List<SubqueryExpr> subqueryExprs, LogicalPlan childPlan,
Map<SubqueryExpr, Optional<MarkJoinSlotReference>> subqueryToMarkJoinSlot,
Map<SubqueryExpr, Expression> subqueryCorrespondingConject,
CascadesContext ctx,
Optional<Expression> conjunct, boolean isProject) {
LogicalPlan tmpPlan = childPlan;
for (int i = 0; i < subqueryExprs.size(); ++i) {
SubqueryExpr subqueryExpr = subqueryExprs.get(i);
if (!ctx.subqueryIsAnalyzed(subqueryExpr)) {
tmpPlan = addApply(subqueryExpr, tmpPlan, ctx);
tmpPlan = addApply(subqueryExpr, tmpPlan,
subqueryToMarkJoinSlot, subqueryCorrespondingConject, ctx, conjunct,
isProject, subqueryExprs.size() == 1);
}
}
return tmpPlan;
}
private LogicalPlan addApply(SubqueryExpr subquery, Plan childPlan, CascadesContext ctx) {
private LogicalPlan addApply(SubqueryExpr subquery, LogicalPlan childPlan,
Map<SubqueryExpr, Optional<MarkJoinSlotReference>> subqueryToMarkJoinSlot,
Map<SubqueryExpr, Expression> subqueryCorrespondingConject,
CascadesContext ctx, Optional<Expression> conjunct,
boolean isProject, boolean singleSubquery) {
ctx.setSubqueryExprIsAnalyzed(subquery, true);
LogicalApply newApply = new LogicalApply(
subquery.getCorrelateSlots(),
subquery, Optional.empty(), childPlan, subquery.getQueryPlan());
List<Slot> projects = new ArrayList<>(childPlan.getOutput());
subquery, Optional.empty(),
subqueryToMarkJoinSlot.get(subquery),
mergeScalarSubConjectAndFilterConject(
subquery, subqueryCorrespondingConject,
conjunct, isProject, singleSubquery),
childPlan, subquery.getQueryPlan());
List<NamedExpression> projects = new ArrayList<>(childPlan.getOutput());
if (subquery instanceof ScalarSubquery) {
projects.add(subquery.getQueryPlan().getOutput().get(0));
}
return new LogicalProject(projects, newApply);
}
/**
 * Tells whether {@code subquery} is an uncorrelated scalar subquery whose enclosing conjunct is a disjunction.
 *
 * <p>Despite the name, this method does not verify that the subquery is the only one in the conjunct;
 * the caller combines it with its own "single subquery" flag.
 *
 * @param subquery the subquery being turned into an apply
 * @param conjunct the rewritten conjunct the subquery came from, empty when there is none
 * @return true only if the subquery is a {@link ScalarSubquery} with no correlated slots and the
 *         conjunct is present with an {@link Or} at its root
 */
private boolean checkSingleScalarWithOr(SubqueryExpr subquery,
        Optional<Expression> conjunct) {
    if (!(subquery instanceof ScalarSubquery)) {
        return false;
    }
    if (!conjunct.isPresent() || !(conjunct.get() instanceof Or)) {
        return false;
    }
    return subquery.getCorrelateSlots().isEmpty();
}
/**
 * Picks the conjunct, if any, that is handed to the LogicalApply created for {@code subquery}.
 *
 * <p>Case 1: the subquery is the only one of its conjunct, it is an uncorrelated scalar subquery and the
 * conjunct is a disjunction. The whole conjunct is returned, so that it can become the join conjunct.
 * e.g.
 * select * from t1 where k1 > scalarSub(sum(c1)) or k2 > 10;
 *   LogicalJoin(otherConjunct[k1 > sum(c1) or k2 > 10])
 *
 * <p>Case 2: a conjunct was recorded for the subquery and a filter (not a project) is being rewritten.
 * Only the recorded conjunct is returned.
 * e.g.
 * select * from t1 where k1 > scalarSub(sum(c1)) or k2 in inSub(c2) or k2 > 10;
 *   LogicalFilter($c$1 or $c$2 or k2 > 10)
 *     LogicalJoin(otherConjunct[k2 = c2])      ---> inSub
 *       LogicalJoin(otherConjunct[k1 > sum(c1)]) ---> scalarSub
 *
 * <p>In every other case nothing is returned.
 *
 * @param subquery the subquery being turned into an apply
 * @param subqueryCorrespondingConject conjuncts recorded per subquery while the expression was rewritten
 * @param conjunct the rewritten conjunct the subquery came from, empty when there is none
 * @param isProject true when the subquery comes from a project list rather than a filter
 * @param singleSubquery true when the subquery is the only one being processed for this conjunct
 * @return the conjunct to attach to the apply, or empty
 */
private Optional<Expression> mergeScalarSubConjectAndFilterConject(
        SubqueryExpr subquery,
        Map<SubqueryExpr, Expression> subqueryCorrespondingConject,
        Optional<Expression> conjunct,
        boolean isProject,
        boolean singleSubquery) {
    // Case 1: the complete disjunction travels with the single uncorrelated scalar subquery.
    if (singleSubquery && checkSingleScalarWithOr(subquery, conjunct)) {
        return conjunct;
    }
    // No recorded conjunct, or we are in a project: nothing to attach.
    if (!subqueryCorrespondingConject.containsKey(subquery) || isProject) {
        return Optional.empty();
    }
    // Case 2: attach only the conjunct recorded for this subquery.
    return Optional.of(subqueryCorrespondingConject.get(subquery));
}
/**
* The Subquery in the LogicalFilter will change to LogicalApply, so we must replace the origin Subquery.
* LogicalFilter(predicate(contain subquery)) -> LogicalFilter(predicate(not contain subquery)
@ -133,32 +235,207 @@ public class SubqueryToApply implements AnalysisRuleFactory {
*
* after:
* 1.filter(t1.a = b);
* 2.filter(True);
* 3.filter(True);
* 2.isMarkJoin ? filter(MarkJoinSlotReference) : filter(True);
* 3.isMarkJoin ? filter(MarkJoinSlotReference) : filter(True);
*/
private static class ReplaceSubquery extends DefaultExpressionRewriter<Void> {
public Set<Expression> replace(Set<Expression> expressions) {
return expressions.stream().map(expr -> expr.accept(this, null))
private static class ReplaceSubquery extends DefaultExpressionRewriter<SubqueryContext> {
private final StatementContext statementContext;
private boolean isMarkJoin;
private final boolean isProject;
/**
 * Creates a rewriter that substitutes subquery expressions inside a filter or project expression.
 *
 * @param statementContext used to generate column names for mark join slots; must not be null
 * @param isProject true when the rewritten expressions belong to a project list, false for a filter
 * @throws NullPointerException if {@code statementContext} is null
 */
public ReplaceSubquery(StatementContext statementContext,
        boolean isProject) {
    this.isProject = isProject;
    this.statementContext = Objects.requireNonNull(statementContext, "statementContext can't be null");
}
public Set<Expression> replace(Set<Expression> expressions, SubqueryContext subqueryContext) {
return expressions.stream().map(expr -> expr.accept(this, subqueryContext))
.collect(ImmutableSet.toImmutableSet());
}
public Expression replace(Expression expressions) {
return expressions.accept(this, null);
public Expression replace(Expression expressions, SubqueryContext subqueryContext) {
return expressions.accept(this, subqueryContext);
}
@Override
public Expression visitExistsSubquery(Exists exists, Void context) {
return BooleanLiteral.TRUE;
public Expression visitExistsSubquery(Exists exists, SubqueryContext context) {
// EXISTS still evaluates to TRUE when the subquery result is NULL.
// On empty input an aggregate still returns NULL, so if the subquery contains an aggregate
// the result of EXISTS is always considered to be true.
MarkJoinSlotReference markJoinSlotReference;
if (exists.getQueryPlan().anyMatch(Aggregate.class::isInstance)) {
markJoinSlotReference =
new MarkJoinSlotReference(statementContext.generateColumnName(), true);
} else {
markJoinSlotReference =
new MarkJoinSlotReference(statementContext.generateColumnName());
}
if (isMarkJoin) {
context.setSubqueryToMarkJoinSlot(exists, Optional.of(markJoinSlotReference));
}
return isMarkJoin ? markJoinSlotReference : BooleanLiteral.TRUE;
}
@Override
public Expression visitInSubquery(InSubquery in, Void context) {
return BooleanLiteral.TRUE;
public Expression visitInSubquery(InSubquery in, SubqueryContext context) {
MarkJoinSlotReference markJoinSlotReference =
new MarkJoinSlotReference(statementContext.generateColumnName());
if (isMarkJoin) {
context.setSubqueryToMarkJoinSlot(in, Optional.of(markJoinSlotReference));
}
return isMarkJoin ? markJoinSlotReference : BooleanLiteral.TRUE;
}
@Override
public Expression visitScalarSubquery(ScalarSubquery scalar, Void context) {
return scalar.getQueryPlan().getOutput().get(0);
public Expression visitScalarSubquery(ScalarSubquery scalar, SubqueryContext context) {
context.setSubqueryCorrespondingConject(scalar, scalar.getQueryPlan().getOutput().get(0));
// When there is only one scalarSubQuery and CorrelateSlots is empty
// it will not be processed by MarkJoin, so it can be returned directly
if (context.onlySingleSubquery() && scalar.getCorrelateSlots().isEmpty()) {
return scalar.getQueryPlan().getOutput().get(0);
}
MarkJoinSlotReference markJoinSlotReference =
new MarkJoinSlotReference(statementContext.generateColumnName());
if (isMarkJoin) {
context.setSubqueryToMarkJoinSlot(scalar, Optional.of(markJoinSlotReference));
}
return isMarkJoin ? markJoinSlotReference : scalar.getQueryPlan().getOutput().get(0);
}
/**
 * Rewrites {@code NOT (a op b)} where {@code a} or {@code b} is directly a scalar subquery.
 *
 * <p>The child is rewritten first, then the conjunct recorded for the subquery is wrapped in a
 * {@link Not}. If a mark join slot exists for the subquery, the rewritten child is returned as is,
 * because the negation now lives in the recorded conjunct. Otherwise the rewritten child is wrapped
 * in a new {@link Not}. Any other shape of NOT goes through the default visitor.
 */
@Override
public Expression visitNot(Not not, SubqueryContext context) {
    Expression child = not.child();
    if (!(child instanceof BinaryOperator)) {
        return visit(not, context);
    }
    BinaryOperator binaryChild = (BinaryOperator) child;
    boolean leftIsScalar = binaryChild.left() instanceof ScalarSubquery;
    if (!leftIsScalar && !(binaryChild.right() instanceof ScalarSubquery)) {
        return visit(not, context);
    }

    Expression rewrittenChild = replace(child, context);
    ScalarSubquery scalarSubquery =
            (ScalarSubquery) (leftIsScalar ? binaryChild.left() : binaryChild.right());
    // The conjunct recorded while rewriting the child has to carry the negation as well.
    context.updateSubqueryCorrespondingConjunctInNot(scalarSubquery);
    if (context.getSubqueryToMarkJoinSlotValue(scalarSubquery).isPresent()) {
        return rewrittenChild;
    }
    return new Not(rewrittenChild);
}
/**
 * Rewrites both children of a binary operator and tracks whether we are inside a "mark join" region.
 *
 * <p>Mark join mode is switched on when this operator is an {@link Or} and either subtree contains a
 * subquery, or when it was already on for an enclosing operator. When one of the direct children is a
 * scalar subquery, building the result is delegated to the context's replaceBinaryOperator; otherwise
 * the operator is rebuilt with the rewritten children.
 */
@Override
public Expression visitBinaryOperator(BinaryOperator binaryOperator, SubqueryContext context) {
// Only DIRECT children are checked here; nested scalar subqueries are handled by recursive visits.
boolean atLeastOneChildIsScalarSubquery =
binaryOperator.left() instanceof ScalarSubquery || binaryOperator.right() instanceof ScalarSubquery;
// anyMatch inspects the whole subtree, so an Or far above a subquery also enables mark join mode.
boolean currentMarkJoin = ((binaryOperator.left().anyMatch(SubqueryExpr.class::isInstance)
|| binaryOperator.right().anyMatch(SubqueryExpr.class::isInstance))
&& (binaryOperator instanceof Or)) || isMarkJoin;
isMarkJoin = currentMarkJoin;
Expression left = replace(binaryOperator.left(), context);
// Visiting the left subtree may have reassigned the field (nested binary operators set it too),
// so restore this operator's value before visiting the right subtree.
isMarkJoin = currentMarkJoin;
Expression right = replace(binaryOperator.right(), context);
// NOTE(review): the field is left at whatever the right subtree set it to and is not restored to
// the value it had on entry; callers that are binary operators reset it themselves, as above.
if (atLeastOneChildIsScalarSubquery) {
return context.replaceBinaryOperator(binaryOperator, left, right, isProject);
}
return binaryOperator.withChildren(left, right);
}
}
/**
 * State shared between the expression rewriter and the apply builder for one filter conjunct
 * or one project list.
 *
 * <p>subqueryToMarkJoinSlot: the mark join slot of each subquery, empty until one is assigned.
 * rule:
 *   For inSubquery and exists: the subquery itself is replaced by the markSlotReference
 *   e.g.
 *   logicalFilter(predicate=exists) ---> logicalFilter(predicate=$c$1)
 *   For scalarSubquery: the binary operator connected to it is replaced by the markSlotReference
 *   e.g.
 *   logicalFilter(predicate=k1 > scalarSubquery) ---> logicalFilter(predicate=$c$1)
 *
 * <p>subqueryCorrespondingConjunct: the conjunct recorded for a subquery.
 * rule:
 *   {@link #setSubqueryCorrespondingConject} stores whatever expression the rewriter passes in.
 *   {@link #replaceBinaryOperator} overwrites it with the type-coerced binary operator, or, for a
 *   project, with the rewritten operand on the scalar subquery's side.
 *   {@link #updateSubqueryCorrespondingConjunctInNot} wraps the recorded conjunct in a {@link Not}.
 */
private static class SubqueryContext {
    private final Map<SubqueryExpr, Optional<MarkJoinSlotReference>> subqueryToMarkJoinSlot;
    private final Map<SubqueryExpr, Expression> subqueryCorrespondingConjunct;

    /**
     * Registers every subquery with an empty mark join slot; no conjunct is recorded yet.
     * Both maps keep insertion order.
     */
    public SubqueryContext(Set<SubqueryExpr> subqueryExprs) {
        int expectedSize = subqueryExprs.size();
        this.subqueryToMarkJoinSlot = new LinkedHashMap<>(expectedSize);
        this.subqueryCorrespondingConjunct = new LinkedHashMap<>(expectedSize);
        for (SubqueryExpr subqueryExpr : subqueryExprs) {
            subqueryToMarkJoinSlot.put(subqueryExpr, Optional.empty());
        }
    }

    /** Returns the live (not copied) map from subquery to its optional mark join slot. */
    public Map<SubqueryExpr, Optional<MarkJoinSlotReference>> getSubqueryToMarkJoinSlot() {
        return subqueryToMarkJoinSlot;
    }

    /** Returns the live (not copied) map from subquery to its recorded conjunct. */
    public Map<SubqueryExpr, Expression> getSubqueryCorrespondingConjunct() {
        return subqueryCorrespondingConjunct;
    }

    /** Returns the mark join slot entry of the subquery, or null if the subquery was never registered. */
    public Optional<MarkJoinSlotReference> getSubqueryToMarkJoinSlotValue(SubqueryExpr subqueryExpr) {
        return subqueryToMarkJoinSlot.get(subqueryExpr);
    }

    /** Sets (or overwrites) the mark join slot entry of the subquery. */
    public void setSubqueryToMarkJoinSlot(SubqueryExpr subquery,
            Optional<MarkJoinSlotReference> markJoinSlotReference) {
        subqueryToMarkJoinSlot.put(subquery, markJoinSlotReference);
    }

    /** Records (or overwrites) the conjunct belonging to the subquery. */
    public void setSubqueryCorrespondingConject(SubqueryExpr subquery,
            Expression expression) {
        subqueryCorrespondingConjunct.put(subquery, expression);
    }

    /** Returns true when exactly one subquery is registered in this context. */
    public boolean onlySingleSubquery() {
        return 1 == subqueryToMarkJoinSlot.size();
    }

    /** Wraps the conjunct recorded for the subquery in a {@link Not}; does nothing if none is recorded. */
    public void updateSubqueryCorrespondingConjunctInNot(SubqueryExpr subquery) {
        if (!subqueryCorrespondingConjunct.containsKey(subquery)) {
            return;
        }
        Expression recorded = subqueryCorrespondingConjunct.get(subquery);
        subqueryCorrespondingConjunct.put(subquery, new Not(recorded));
    }

    /**
     * Builds the replacement for a binary operator that has a scalar subquery as a direct child.
     *
     * <p>With a mark join slot, the operator is replaced by the markSlotReference
     * e.g.
     * logicalFilter(predicate=k1 > scalarSub or exists)
     * -->
     * logicalFilter(predicate=$c$1 or $c$2)
     *
     * <p>Without a mark join slot, the operator is returned after implicit type conversion.
     * e.g.
     * logicalFilter(predicate=k1 > scalarSub(sum(k2)))
     * -->
     * logicalFilter(predicate=Cast(k1[#0] as BIGINT) > sum(k2)[#1])
     *
     * <p>As a side effect the conjunct of the subquery is recorded: the coerced operator for a
     * filter, or the operand on the scalar subquery's side for a project.
     *
     * <p>NOTE(review): the rebuilt operator is cast to ComparisonPredicate, so this presumably
     * expects only comparisons to reach here; confirm that other binary operators with a direct
     * scalar subquery child are rejected earlier.
     *
     * @param binaryOperator the original operator, with the scalar subquery still in place
     * @param left the already rewritten left child
     * @param right the already rewritten right child
     * @param isProject true when rewriting a project list, false for a filter
     * @return the mark join slot of the subquery if it has one, otherwise the coerced operator
     */
    public Expression replaceBinaryOperator(BinaryOperator binaryOperator,
            Expression left,
            Expression right,
            boolean isProject) {
        boolean leftIsScalar = binaryOperator.left() instanceof ScalarSubquery;
        ScalarSubquery subquery =
                (ScalarSubquery) (leftIsScalar ? binaryOperator.left() : binaryOperator.right());
        Optional<MarkJoinSlotReference> markSlot = subqueryToMarkJoinSlot.get(subquery);

        // With a mark slot the rewritten child is that slot, which is useless inside the join
        // conjunct; use the subquery's first output on that side instead.
        Expression newLeft = left;
        Expression newRight = right;
        if (markSlot.isPresent()) {
            Expression scalarOutput = subquery.getQueryPlan().getOutput().get(0);
            if (leftIsScalar) {
                newLeft = scalarOutput;
            } else {
                newRight = scalarOutput;
            }
        }

        // Perform implicit type conversion on the connection condition of the scalar subquery.
        Expression newBinary = TypeCoercionUtils.processComparisonPredicate(
                (ComparisonPredicate) binaryOperator.withChildren(newLeft, newRight), newLeft, newRight);

        Expression recorded;
        if (isProject) {
            recorded = leftIsScalar ? newLeft : newRight;
        } else {
            recorded = newBinary;
        }
        subqueryCorrespondingConjunct.put(subquery, recorded);

        Optional<MarkJoinSlotReference> currentSlot = subqueryToMarkJoinSlot.get(subquery);
        return currentSlot.isPresent() ? currentSlot.get() : newBinary;
    }
}
}

View File

@ -52,6 +52,7 @@ public class InnerJoinLAsscom extends OneExplorationRuleFactory {
return innerLogicalJoin(innerLogicalJoin(), group())
.when(topJoin -> checkReorder(topJoin, topJoin.left()))
.whenNot(join -> join.hasJoinHint() || join.left().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
GroupPlan a = bottomJoin.left();
@ -93,7 +94,8 @@ public class InnerJoinLAsscom extends OneExplorationRuleFactory {
public static boolean checkReorder(LogicalJoin<? extends Plan, GroupPlan> topJoin,
LogicalJoin<GroupPlan, GroupPlan> bottomJoin) {
return !bottomJoin.getJoinReorderContext().hasCommuteZigZag()
&& !topJoin.getJoinReorderContext().hasLAsscom();
&& !topJoin.getJoinReorderContext().hasLAsscom()
&& (!bottomJoin.isMarkJoin() && !topJoin.isMarkJoin());
}
/**

View File

@ -58,6 +58,7 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory {
return innerLogicalJoin(logicalProject(innerLogicalJoin()), group())
.when(topJoin -> InnerJoinLAsscom.checkReorder(topJoin, topJoin.left().child()))
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.when(join -> JoinReorderUtils.isAllSlotProject(join.left()))
.then(topJoin -> {
/* ********** init ********** */

View File

@ -53,6 +53,7 @@ public class InnerJoinLeftAssociate extends OneExplorationRuleFactory {
return innerLogicalJoin(group(), innerLogicalJoin())
.when(InnerJoinLeftAssociate::checkReorder)
.whenNot(join -> join.hasJoinHint() || join.right().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.right().isMarkJoin())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.right();
GroupPlan a = topJoin.left();

View File

@ -52,6 +52,7 @@ public class InnerJoinRightAssociate extends OneExplorationRuleFactory {
return innerLogicalJoin(innerLogicalJoin(), group())
.when(InnerJoinRightAssociate::checkReorder)
.whenNot(join -> join.hasJoinHint() || join.left().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
GroupPlan a = bottomJoin.left();

View File

@ -46,6 +46,7 @@ public class JoinCommute extends OneExplorationRuleFactory {
return logicalJoin()
.when(join -> check(swapType, join))
.whenNot(LogicalJoin::hasJoinHint)
.whenNot(LogicalJoin::isMarkJoin)
.then(join -> {
LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>(
join.getJoinType().swap(),

View File

@ -55,6 +55,7 @@ public class JoinExchange extends OneExplorationRuleFactory {
return innerLogicalJoin(innerLogicalJoin(), innerLogicalJoin())
.when(JoinExchange::checkReorder)
.whenNot(join -> join.hasJoinHint() || join.left().hasJoinHint() || join.right().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin() || join.right().isMarkJoin())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> leftJoin = topJoin.left();
LogicalJoin<GroupPlan, GroupPlan> rightJoin = topJoin.right();

View File

@ -57,6 +57,7 @@ public class OuterJoinAssoc extends OneExplorationRuleFactory {
.when(join -> VALID_TYPE_PAIR_SET.contains(Pair.of(join.left().getJoinType(), join.getJoinType())))
.when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left()))
.when(topJoin -> checkCondition(topJoin, topJoin.left().left().getOutputSet()))
.whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin())
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
GroupPlan a = bottomJoin.left();

View File

@ -56,6 +56,7 @@ public class OuterJoinAssocProject extends OneExplorationRuleFactory {
Pair.of(join.left().child().getJoinType(), join.getJoinType())))
.when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left().child()))
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.when(join -> OuterJoinAssoc.checkCondition(join, join.left().child().left().getOutputSet()))
.when(join -> JoinReorderUtils.isAllSlotProject(join.left()))
.then(topJoin -> {

View File

@ -64,6 +64,7 @@ public class OuterJoinLAsscom extends OneExplorationRuleFactory {
.when(topJoin -> checkReorder(topJoin, topJoin.left()))
.whenNot(join -> join.hasJoinHint() || join.left().hasJoinHint())
.when(topJoin -> checkCondition(topJoin, topJoin.left().right().getOutputExprIdSet()))
.whenNot(LogicalJoin::isMarkJoin)
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
GroupPlan a = bottomJoin.left();

View File

@ -59,6 +59,7 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory {
Pair.of(join.left().child().getJoinType(), join.getJoinType())))
.when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left().child()))
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.when(join -> JoinReorderUtils.isAllSlotProject(join.left()))
.then(topJoin -> {
/* ********** init ********** */

View File

@ -61,6 +61,7 @@ public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory {
.whenNot(topJoin -> topJoin.left().getJoinType().isSemiOrAntiJoin())
.when(this::conditionChecker)
.whenNot(topJoin -> topJoin.hasJoinHint() || topJoin.left().hasJoinHint())
.whenNot(LogicalJoin::isMarkJoin)
.then(topSemiJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topSemiJoin.left();
GroupPlan a = bottomJoin.left();

View File

@ -67,6 +67,7 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
|| topJoin.left().child().getJoinType().isRightOuterJoin())))
.whenNot(topJoin -> topJoin.left().child().getJoinType().isSemiOrAntiJoin())
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.when(join -> JoinReorderUtils.isAllSlotProject(join.left()))
.then(topSemiJoin -> {
LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = topSemiJoin.left();

View File

@ -63,6 +63,7 @@ public class SemiJoinSemiJoinTranspose extends OneExplorationRuleFactory {
return logicalJoin(logicalJoin(), group())
.when(this::typeChecker)
.whenNot(join -> join.hasJoinHint() || join.left().hasJoinHint())
.whenNot(LogicalJoin::isMarkJoin)
.then(topJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
GroupPlan a = bottomJoin.left();

View File

@ -54,6 +54,7 @@ public class SemiJoinSemiJoinTransposeProject extends OneExplorationRuleFactory
.when(this::typeChecker)
.when(topSemi -> InnerJoinLAsscom.checkReorder(topSemi, topSemi.left().child()))
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.when(join -> JoinReorderUtils.isAllSlotProject(join.left()))
.then(topSemi -> {
LogicalJoin<GroupPlan, GroupPlan> bottomSemi = topSemi.left().child();

View File

@ -204,7 +204,8 @@ public class ExpressionRewrite implements RewriteRuleFactory {
return join;
}
return new LogicalJoin<>(join.getJoinType(), rewriteHashJoinConjuncts,
rewriteOtherJoinConjuncts, join.getHint(), join.left(), join.right());
rewriteOtherJoinConjuncts, join.getHint(), join.getMarkJoinSlotReference(),
join.left(), join.right());
}).toRule(RuleType.REWRITE_JOIN_EXPRESSION);
}
}

View File

@ -35,6 +35,7 @@ public class LogicalJoinToHashJoin extends OneImplementationRuleFactory {
join.getHashJoinConjuncts(),
join.getOtherJoinConjuncts(),
join.getHint(),
join.getMarkJoinSlotReference(),
join.getLogicalProperties(),
join.left(),
join.right())

View File

@ -34,6 +34,7 @@ public class LogicalJoinToNestedLoopJoin extends OneImplementationRuleFactory {
join.getJoinType(),
join.getHashJoinConjuncts(),
join.getOtherJoinConjuncts(),
join.getMarkJoinSlotReference(),
join.getLogicalProperties(),
join.left(),
join.right())

View File

@ -91,20 +91,31 @@ public class ExistsApplyToJoin extends OneRewriteRuleFactory {
private Plan correlatedToJoin(LogicalApply apply) {
Optional<Expression> correlationFilter = apply.getCorrelationFilter();
Expression predicate = null;
if (correlationFilter.isPresent() && apply.getSubCorrespondingConject().isPresent()) {
predicate = ExpressionUtils.and(correlationFilter.get(),
(Expression) apply.getSubCorrespondingConject().get());
} else if (apply.getSubCorrespondingConject().isPresent()) {
predicate = (Expression) apply.getSubCorrespondingConject().get();
} else if (correlationFilter.isPresent()) {
predicate = correlationFilter.get();
}
if (((Exists) apply.getSubqueryExpr()).isNot()) {
return new LogicalJoin<>(JoinType.LEFT_ANTI_JOIN, ExpressionUtils.EMPTY_CONDITION,
correlationFilter
.map(ExpressionUtils::extractConjunction)
.orElse(ExpressionUtils.EMPTY_CONDITION),
predicate != null
? ExpressionUtils.extractConjunction(predicate)
: ExpressionUtils.EMPTY_CONDITION,
JoinHint.NONE,
apply.getMarkJoinSlotReference(),
(LogicalPlan) apply.left(), (LogicalPlan) apply.right());
} else {
return new LogicalJoin<>(JoinType.LEFT_SEMI_JOIN, ExpressionUtils.EMPTY_CONDITION,
correlationFilter
.map(ExpressionUtils::extractConjunction)
.orElse(ExpressionUtils.EMPTY_CONDITION),
predicate != null
? ExpressionUtils.extractConjunction(predicate)
: ExpressionUtils.EMPTY_CONDITION,
JoinHint.NONE,
apply.getMarkJoinSlotReference(),
(LogicalPlan) apply.left(), (LogicalPlan) apply.right());
}
}
@ -122,7 +133,10 @@ public class ExistsApplyToJoin extends OneRewriteRuleFactory {
Alias alias = new Alias(new Count(), "count(*)");
LogicalAggregate newAgg = new LogicalAggregate<>(new ArrayList<>(),
ImmutableList.of(alias), newLimit);
LogicalJoin newJoin = new LogicalJoin<>(JoinType.CROSS_JOIN,
LogicalJoin newJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, ExpressionUtils.EMPTY_CONDITION,
unapply.getSubCorrespondingConject().isPresent()
? ExpressionUtils.extractConjunction((Expression) unapply.getSubCorrespondingConject().get())
: ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, unapply.getMarkJoinSlotReference(),
(LogicalPlan) unapply.left(), newAgg);
return new LogicalFilter<>(ImmutableSet.of(new EqualTo(newAgg.getOutput().get(0),
new IntegerLiteral(0))), newJoin);
@ -130,6 +144,10 @@ public class ExistsApplyToJoin extends OneRewriteRuleFactory {
private Plan unCorrelatedExist(LogicalApply unapply) {
LogicalLimit newLimit = new LogicalLimit<>(1, 0, LimitPhase.ORIGIN, (LogicalPlan) unapply.right());
return new LogicalJoin<>(JoinType.CROSS_JOIN, (LogicalPlan) unapply.left(), newLimit);
return new LogicalJoin<>(JoinType.CROSS_JOIN, ExpressionUtils.EMPTY_CONDITION,
unapply.getSubCorrespondingConject().isPresent()
? ExpressionUtils.extractConjunction((Expression) unapply.getSubCorrespondingConject().get())
: ExpressionUtils.EMPTY_CONDITION,
JoinHint.NONE, unapply.getMarkJoinSlotReference(), (LogicalPlan) unapply.left(), newLimit);
}
}

View File

@ -42,7 +42,7 @@ public class ExtractFilterFromCrossJoin extends OneRewriteRuleFactory {
.then(join -> {
LogicalJoin<Plan, Plan> newJoin = new LogicalJoin<>(JoinType.CROSS_JOIN,
ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, join.getHint(),
join.left(), join.right());
join.getMarkJoinSlotReference(), join.left(), join.right());
Set<Expression> predicates = Stream.concat(join.getHashJoinConjuncts().stream(),
join.getOtherJoinConjuncts().stream())
.collect(ImmutableSet.toImmutableSet());

View File

@ -74,6 +74,7 @@ public class FindHashConditionForJoin extends OneRewriteRuleFactory {
combinedHashJoinConjuncts,
remainedNonHashJoinConjuncts,
join.getHint(),
join.getMarkJoinSlotReference(),
join.left(), join.right());
}).toRule(RuleType.FIND_HASH_CONDITION_FOR_JOIN);
}

View File

@ -30,6 +30,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.types.BitmapType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.collect.Lists;
@ -46,14 +47,19 @@ public class InApplyToJoin extends OneRewriteRuleFactory {
public Rule build() {
return logicalApply().when(LogicalApply::isIn).then(apply -> {
Expression predicate;
Expression left = ((InSubquery) apply.getSubqueryExpr()).getCompareExpr();
Expression right = apply.right().getOutput().get(0);
if (apply.isCorrelated()) {
predicate = ExpressionUtils.and(
new EqualTo(((InSubquery) apply.getSubqueryExpr()).getCompareExpr(),
apply.right().getOutput().get(0)),
TypeCoercionUtils.processComparisonPredicate(
new EqualTo(left, right), left, right),
apply.getCorrelationFilter().get());
} else {
predicate = new EqualTo(((InSubquery) apply.getSubqueryExpr()).getCompareExpr(),
apply.right().getOutput().get(0));
predicate = TypeCoercionUtils.processComparisonPredicate(new EqualTo(left, right), left, right);
}
if (apply.getSubCorrespondingConject().isPresent()) {
predicate = ExpressionUtils.and(predicate, apply.getSubCorrespondingConject().get());
}
//TODO nereids should support bitmap runtime filter in future
@ -67,12 +73,12 @@ public class InApplyToJoin extends OneRewriteRuleFactory {
predicate.nullable() ? JoinType.NULL_AWARE_LEFT_ANTI_JOIN : JoinType.LEFT_ANTI_JOIN,
Lists.newArrayList(),
conjuncts,
JoinHint.NONE,
JoinHint.NONE, apply.getMarkJoinSlotReference(),
apply.left(), apply.right());
} else {
return new LogicalJoin<>(JoinType.LEFT_SEMI_JOIN, Lists.newArrayList(),
conjuncts,
JoinHint.NONE,
JoinHint.NONE, apply.getMarkJoinSlotReference(),
apply.left(), apply.right());
}
}).toRule(RuleType.IN_APPLY_TO_JOIN);

View File

@ -79,7 +79,8 @@ public class PullUpCorrelatedFilterUnderApplyAggregateProject extends OneRewrite
LogicalFilter newFilter = new LogicalFilter<>(filter.getConjuncts(), newProject);
LogicalAggregate newAgg = agg.withChildren(ImmutableList.of(newFilter));
return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(),
apply.getCorrelationFilter(), apply.left(), newAgg);
apply.getCorrelationFilter(), apply.getMarkJoinSlotReference(),
apply.getSubCorrespondingConject(), apply.left(), newAgg);
}).toRule(RuleType.PULL_UP_CORRELATED_FILTER_UNDER_APPLY_AGGREGATE_PROJECT);
}
}

View File

@ -57,7 +57,8 @@ public class PullUpProjectUnderApply extends OneRewriteRuleFactory {
.then(apply -> {
LogicalProject<Plan> project = apply.right();
LogicalApply newCorrelate = new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(),
apply.getCorrelationFilter(), apply.left(), project.child());
apply.getCorrelationFilter(), apply.getMarkJoinSlotReference(),
apply.getSubCorrespondingConject(), apply.left(), project.child());
List<NamedExpression> newProjects = new ArrayList<>();
newProjects.addAll(apply.left().getOutput());
if (apply.getSubqueryExpr() instanceof ScalarSubquery) {

View File

@ -40,6 +40,7 @@ public class PushFilterInsideJoin extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalFilter(logicalJoin())
.whenNot(filter -> filter.child().isMarkJoin())
// TODO: current just handle cross/inner join.
.when(filter -> filter.child().getJoinType().isCrossJoin()
|| filter.child().getJoinType().isInnerJoin())
@ -48,7 +49,8 @@ public class PushFilterInsideJoin extends OneRewriteRuleFactory {
LogicalJoin<Plan, Plan> join = filter.child();
otherConditions.addAll(join.getOtherJoinConjuncts());
return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(),
otherConditions, join.getHint(), join.left(), join.right());
otherConditions, join.getHint(), join.getMarkJoinSlotReference(),
join.left(), join.right());
}).toRule(RuleType.PUSH_FILTER_INSIDE_JOIN);
}
}

View File

@ -139,6 +139,7 @@ public class PushdownFilterThroughJoin extends OneRewriteRuleFactory {
join.getHashJoinConjuncts(),
joinConditions,
join.getHint(),
join.getMarkJoinSlotReference(),
PlanUtils.filterOrSelf(leftPredicates, join.left()),
PlanUtils.filterOrSelf(rightPredicates, join.right())));
}).toRule(RuleType.PUSHDOWN_FILTER_THROUGH_JOIN);

View File

@ -90,7 +90,7 @@ public class PushdownJoinOtherCondition extends OneRewriteRuleFactory {
Plan right = PlanUtils.filterOrSelf(rightConjuncts, join.right());
return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(),
remainingOther, join.getHint(), left, right);
remainingOther, join.getHint(), join.getMarkJoinSlotReference(), left, right);
}).toRule(RuleType.PUSHDOWN_JOIN_OTHER_CONDITION);
}

View File

@ -24,6 +24,7 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinHint.JoinHintType;
import org.apache.doris.nereids.trees.plans.JoinType;
@ -45,6 +46,7 @@ import com.google.common.collect.Maps;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@ -80,12 +82,12 @@ public class ReorderJoin extends OneRewriteRuleFactory {
}
LogicalFilter<Plan> filter = ctx.root;
Map<Plan, JoinHintType> planToHintType = Maps.newHashMap();
Plan plan = joinToMultiJoin(filter, planToHintType);
Map<Plan, JoinHintTypeAndMarkJoinSlot> planToJoinMembers = Maps.newHashMap();
Plan plan = joinToMultiJoin(filter, planToJoinMembers);
Preconditions.checkState(plan instanceof MultiJoin);
MultiJoin multiJoin = (MultiJoin) plan;
ctx.statementContext.setMaxNArayInnerJoin(multiJoin.children().size());
Plan after = multiJoinToJoin(multiJoin, planToHintType);
Plan after = multiJoinToJoin(multiJoin, planToJoinMembers);
return after;
}).toRule(RuleType.REORDER_JOIN);
}
@ -95,7 +97,7 @@ public class ReorderJoin extends OneRewriteRuleFactory {
* {@link LogicalJoin} or {@link LogicalFilter}--{@link LogicalJoin}
* --> {@link MultiJoin}
*/
public Plan joinToMultiJoin(Plan plan, Map<Plan, JoinHintType> planToHintType) {
public Plan joinToMultiJoin(Plan plan, Map<Plan, JoinHintTypeAndMarkJoinSlot> planToJoinMembers) {
// subtree can't specify the end of Pattern. so end can be GroupPlan or Filter
if (nonJoinAndNonFilter(plan)
|| (plan instanceof LogicalFilter && nonJoinAndNonFilter(plan.child(0)))) {
@ -125,10 +127,12 @@ public class ReorderJoin extends OneRewriteRuleFactory {
}
// recursively convert children.
planToHintType.put(join.left(), join.getLeftHint());
Plan left = joinToMultiJoin(join.left(), planToHintType);
planToHintType.put(join.right(), join.getRightHint());
Plan right = joinToMultiJoin(join.right(), planToHintType);
planToJoinMembers.put(join.left(),
new JoinHintTypeAndMarkJoinSlot(join.getLeftHint(), join.getLeftMarkJoinSlotReference()));
Plan left = joinToMultiJoin(join.left(), planToJoinMembers);
planToJoinMembers.put(join.right(),
new JoinHintTypeAndMarkJoinSlot(join.getRightHint(), join.getMarkJoinSlotReference()));
Plan right = joinToMultiJoin(join.right(), planToJoinMembers);
boolean changeLeft = join.getJoinType().isRightJoin()
|| join.getJoinType().isFullOuterJoin();
@ -211,7 +215,7 @@ public class ReorderJoin extends OneRewriteRuleFactory {
* A B C D F ──► A B C │ D F ──► MJ(FOJ MJ(A,B,C) MJ(D,F))
* </pre>
*/
public Plan multiJoinToJoin(MultiJoin multiJoin, Map<Plan, JoinHintType> planToHintType) {
public Plan multiJoinToJoin(MultiJoin multiJoin, Map<Plan, JoinHintTypeAndMarkJoinSlot> planToJoinMembers) {
if (multiJoin.arity() == 1) {
return PlanUtils.filterOrSelf(ImmutableSet.copyOf(multiJoin.getJoinFilter()), multiJoin.child(0));
}
@ -221,7 +225,7 @@ public class ReorderJoin extends OneRewriteRuleFactory {
for (Plan child : multiJoin.children()) {
if (child instanceof MultiJoin) {
MultiJoin childMultiJoin = (MultiJoin) child;
builder.add(multiJoinToJoin(childMultiJoin, planToHintType));
builder.add(multiJoinToJoin(childMultiJoin, planToJoinMembers));
} else {
builder.add(child);
}
@ -239,6 +243,7 @@ public class ReorderJoin extends OneRewriteRuleFactory {
Map<Boolean, List<Expression>> split = multiJoin.getJoinFilter().stream()
.collect(Collectors.partitioningBy(expr ->
Utils.isIntersecting(rightOutputExprIdSet, expr.getInputSlotExprIds())
|| expr.anyMatch(MarkJoinSlotReference.class::isInstance)
));
remainingFilter = split.get(true);
List<Expression> pushedFilter = split.get(false);
@ -246,7 +251,7 @@ public class ReorderJoin extends OneRewriteRuleFactory {
multiJoinHandleChildren.children().subList(0, multiJoinHandleChildren.arity() - 1),
pushedFilter,
JoinType.INNER_JOIN,
ExpressionUtils.EMPTY_CONDITION), planToHintType);
ExpressionUtils.EMPTY_CONDITION), planToJoinMembers);
} else if (multiJoinHandleChildren.getJoinType().isRightJoin()) {
left = multiJoinHandleChildren.child(0);
Set<ExprId> leftOutputExprIdSet = left.getOutputExprIdSet();
@ -260,13 +265,13 @@ public class ReorderJoin extends OneRewriteRuleFactory {
multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity()),
pushedFilter,
JoinType.INNER_JOIN,
ExpressionUtils.EMPTY_CONDITION), planToHintType);
ExpressionUtils.EMPTY_CONDITION), planToJoinMembers);
} else {
remainingFilter = multiJoin.getJoinFilter();
Preconditions.checkState(multiJoinHandleChildren.arity() == 2);
List<Plan> children = multiJoinHandleChildren.children().stream().map(child -> {
if (child instanceof MultiJoin) {
return multiJoinToJoin((MultiJoin) child, planToHintType);
return multiJoinToJoin((MultiJoin) child, planToJoinMembers);
} else {
return child;
}
@ -278,7 +283,8 @@ public class ReorderJoin extends OneRewriteRuleFactory {
return PlanUtils.filterOrSelf(ImmutableSet.copyOf(remainingFilter), new LogicalJoin<>(
multiJoinHandleChildren.getJoinType(),
ExpressionUtils.EMPTY_CONDITION, multiJoinHandleChildren.getNotInnerJoinConditions(),
JoinHint.fromRightPlanHintType(planToHintType.getOrDefault(right, JoinHintType.NONE)),
JoinHint.fromRightPlanHintType(getJoinHintType(planToJoinMembers, right)),
getMarkJoinSlotReference(planToJoinMembers, right),
left, right));
}
@ -291,7 +297,7 @@ public class ReorderJoin extends OneRewriteRuleFactory {
while (usedPlansIndex.size() != multiJoinHandleChildren.children().size()) {
LogicalJoin<? extends Plan, ? extends Plan> join = findInnerJoin(left, multiJoinHandleChildren.children(),
joinFilter, usedPlansIndex, planToHintType);
joinFilter, usedPlansIndex, planToJoinMembers);
join.getHashJoinConjuncts().forEach(joinFilter::remove);
join.getOtherJoinConjuncts().forEach(joinFilter::remove);
@ -321,7 +327,8 @@ public class ReorderJoin extends OneRewriteRuleFactory {
* @return InnerJoin or CrossJoin{left, last of [candidates]}
*/
private LogicalJoin<? extends Plan, ? extends Plan> findInnerJoin(Plan left, List<Plan> candidates,
Set<Expression> joinFilter, Set<Integer> usedPlansIndex, Map<Plan, JoinHintType> planToHintType) {
Set<Expression> joinFilter, Set<Integer> usedPlansIndex,
Map<Plan, JoinHintTypeAndMarkJoinSlot> planToJoinMembers) {
List<Expression> otherJoinConditions = Lists.newArrayList();
Set<ExprId> leftOutputExprIdSet = left.getOutputExprIdSet();
for (int i = 0; i < candidates.size(); i++) {
@ -350,7 +357,8 @@ public class ReorderJoin extends OneRewriteRuleFactory {
usedPlansIndex.add(i);
return new LogicalJoin<>(JoinType.INNER_JOIN,
hashJoinConditions, otherJoinConditions,
JoinHint.fromRightPlanHintType(planToHintType.getOrDefault(candidate, JoinHintType.NONE)),
JoinHint.fromRightPlanHintType(getJoinHintType(planToJoinMembers, candidate)),
getMarkJoinSlotReference(planToJoinMembers, candidate),
left, candidate);
}
}
@ -365,7 +373,8 @@ public class ReorderJoin extends OneRewriteRuleFactory {
return new LogicalJoin<>(JoinType.CROSS_JOIN,
ExpressionUtils.EMPTY_CONDITION,
otherJoinConditions,
JoinHint.fromRightPlanHintType(planToHintType.getOrDefault(right, JoinHintType.NONE)),
JoinHint.fromRightPlanHintType(getJoinHintType(planToJoinMembers, right)),
getMarkJoinSlotReference(planToJoinMembers, right),
left, right);
}
@ -375,4 +384,33 @@ public class ReorderJoin extends OneRewriteRuleFactory {
/**
 * Tells whether {@code plan} is neither a {@link LogicalJoin} nor a {@link LogicalFilter}.
 *
 * @param plan the plan node to classify
 * @return {@code true} when the node is of neither type
 */
private boolean nonJoinAndNonFilter(Plan plan) {
    boolean joinOrFilter = plan instanceof LogicalJoin || plan instanceof LogicalFilter;
    return !joinOrFilter;
}
/**
 * Looks up the join hint type recorded for {@code plan}.
 *
 * @param planToJoinMembers per-plan join members collected while building the MultiJoin
 * @param plan the plan whose hint is requested
 * @return the recorded hint type, or {@link JoinHintType#NONE} when the plan has no entry
 */
private JoinHintType getJoinHintType(Map<Plan, JoinHintTypeAndMarkJoinSlot> planToJoinMembers, Plan plan) {
    JoinHintTypeAndMarkJoinSlot members = planToJoinMembers.get(plan);
    if (members == null) {
        return JoinHintType.NONE;
    }
    return members.getJoinHintType();
}
/**
 * Looks up the mark-join slot recorded for {@code plan}.
 *
 * @param planToJoinMembers per-plan join members collected while building the MultiJoin
 * @param plan the plan whose mark-join slot is requested
 * @return the recorded slot, or an empty {@link Optional} when the plan has no entry
 */
private Optional<MarkJoinSlotReference> getMarkJoinSlotReference(
        Map<Plan, JoinHintTypeAndMarkJoinSlot> planToJoinMembers, Plan plan) {
    JoinHintTypeAndMarkJoinSlot members = planToJoinMembers.get(plan);
    if (members == null) {
        return Optional.empty();
    }
    return members.getMarkJoinSlotReference();
}
/**
 * Immutable holder pairing the join hint type with the optional mark-join slot
 * that were attached to a plan before the join tree was flattened.
 *
 * <p>Both values are normalised on construction, so the getters never return {@code null}.
 */
private static class JoinHintTypeAndMarkJoinSlot {
    private final JoinHintType joinHintType;
    private final Optional<MarkJoinSlotReference> markJoinSlotReference;

    /**
     * @param joinHintType hint type of the plan; {@code null} is treated as {@link JoinHintType#NONE}
     * @param markJoinSlotReference mark-join slot of the plan; {@code null} is treated as empty
     */
    public JoinHintTypeAndMarkJoinSlot(
            JoinHintType joinHintType, Optional<MarkJoinSlotReference> markJoinSlotReference) {
        this.joinHintType = joinHintType == null ? JoinHintType.NONE : joinHintType;
        this.markJoinSlotReference = markJoinSlotReference == null
                ? Optional.empty() : markJoinSlotReference;
    }

    /** Returns the hint type, never {@code null}. */
    public JoinHintType getJoinHintType() {
        return joinHintType;
    }

    /** Returns the mark-join slot, possibly empty but never {@code null}. */
    public Optional<MarkJoinSlotReference> getMarkJoinSlotReference() {
        return markJoinSlotReference;
    }
}
}

View File

@ -60,6 +60,12 @@ public class ScalarApplyToJoin extends OneRewriteRuleFactory {
AssertNumRowsElement.Assertion.EQ),
(LogicalPlan) apply.right());
return new LogicalJoin<>(JoinType.CROSS_JOIN,
ExpressionUtils.EMPTY_CONDITION,
apply.getSubCorrespondingConject().isPresent()
? ExpressionUtils.extractConjunction((Expression) apply.getSubCorrespondingConject().get())
: ExpressionUtils.EMPTY_CONDITION,
JoinHint.NONE,
apply.getMarkJoinSlotReference(),
(LogicalPlan) apply.left(), assertNumRows);
}
@ -73,14 +79,20 @@ public class ScalarApplyToJoin extends OneRewriteRuleFactory {
throw new AnalysisException(
"scalar subquery's correlatedPredicates's operator must be EQ");
});
} else {
throw new AnalysisException("correlationFilter can't be null in correlatedToJoin");
}
return new LogicalJoin<>(JoinType.LEFT_OUTER_JOIN,
return new LogicalJoin<>(JoinType.LEFT_SEMI_JOIN,
ExpressionUtils.EMPTY_CONDITION,
correlationFilter
.map(ExpressionUtils::extractConjunction)
.orElse(ExpressionUtils.EMPTY_CONDITION),
ExpressionUtils.extractConjunction(
apply.getSubCorrespondingConject().isPresent()
? ExpressionUtils.and(
(Expression) apply.getSubCorrespondingConject().get(),
correlationFilter.get())
: correlationFilter.get()),
JoinHint.NONE,
apply.getMarkJoinSlotReference(),
(LogicalPlan) apply.left(),
(LogicalPlan) apply.right());
}

View File

@ -88,6 +88,8 @@ public class UnCorrelatedApplyAggregateFilter extends OneRewriteRuleFactory {
return new LogicalApply<>(apply.getCorrelationSlot(),
apply.getSubqueryExpr(),
ExpressionUtils.optionalAnd(correlatedPredicate),
apply.getMarkJoinSlotReference(),
apply.getSubCorrespondingConject(),
apply.left(), newAgg);
}).toRule(RuleType.UN_CORRELATED_APPLY_AGGREGATE_FILTER);
}

View File

@ -68,7 +68,8 @@ public class UnCorrelatedApplyFilter extends OneRewriteRuleFactory {
Plan child = PlanUtils.filterOrSelf(ImmutableSet.copyOf(unCorrelatedPredicate), filter.child());
return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(),
ExpressionUtils.optionalAnd(correlatedPredicate),
ExpressionUtils.optionalAnd(correlatedPredicate), apply.getMarkJoinSlotReference(),
apply.getSubCorrespondingConject(),
apply.left(), child);
}).toRule(RuleType.UN_CORRELATED_APPLY_FILTER);
}

View File

@ -89,7 +89,8 @@ public class UnCorrelatedApplyProjectFilter extends OneRewriteRuleFactory {
.forEach(projects::add);
LogicalProject newProject = new LogicalProject(projects, child);
return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(),
ExpressionUtils.optionalAnd(correlatedPredicate),
ExpressionUtils.optionalAnd(correlatedPredicate), apply.getMarkJoinSlotReference(),
apply.getSubCorrespondingConject(),
apply.left(), newProject);
}).toRule(RuleType.UN_CORRELATED_APPLY_PROJECT_FILTER);
}

View File

@ -26,6 +26,7 @@ import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IntegralDivide;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.Subtract;
@ -293,4 +294,10 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStatistic, Sta
builder.setSelectivity(1.0);
return builder.build();
}
/**
 * Estimates a {@link MarkJoinSlotReference}: always returns {@link ColumnStatistic#DEFAULT};
 * neither the slot nor the context is consulted.
 */
@Override
public ColumnStatistic visitMarkJoinReference(
MarkJoinSlotReference markJoinSlotReference, StatsDeriveResult context) {
return ColumnStatistic.DEFAULT;
}
}

View File

@ -75,9 +75,9 @@ public class AssertNumRowsElement extends Expression implements LeafExpression,
@Override
public String toString() {
return Utils.toSqlString("AssertNumRowsElement",
"desiredNumOfRows: ",
"desiredNumOfRows",
Long.toString(desiredNumOfRows),
"assertion: ", assertion);
"assertion", assertion);
}
@Override

View File

@ -0,0 +1,64 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BooleanType;
/**
 * A special type of column that is generated to replace the subquery when unnesting the subquery of a MarkJoin.
 *
 * <p>The slot is always created with {@link BooleanType}. Two instances are equal only when the
 * inherited {@link SlotReference} equality holds and their {@code existsHasAgg} flags match;
 * {@link #hashCode()} is kept in step with that definition.
 */
public class MarkJoinSlotReference extends SlotReference implements SlotNotFromChildren {
    // NOTE(review): presumably records whether the EXISTS subquery this slot replaces
    // contains an aggregate - confirm against the rule that creates the slot.
    final boolean existsHasAgg;

    /**
     * Creates a mark-join slot whose {@code existsHasAgg} flag is {@code false}.
     *
     * @param name slot name
     */
    public MarkJoinSlotReference(String name) {
        this(name, false);
    }

    /**
     * Creates a mark-join slot.
     *
     * @param name slot name
     * @param existsHasAgg value of the {@code existsHasAgg} flag
     */
    public MarkJoinSlotReference(String name, boolean existsHasAgg) {
        super(name, BooleanType.INSTANCE, false);
        this.existsHasAgg = existsHasAgg;
    }

    @Override
    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
        return visitor.visitMarkJoinReference(this, context);
    }

    @Override
    public String toString() {
        return super.toString() + "#" + existsHasAgg;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        MarkJoinSlotReference that = (MarkJoinSlotReference) o;
        return this.existsHasAgg == that.existsHasAgg && super.equals(that);
    }

    /**
     * Combines the inherited hash with {@code existsHasAgg} so that the hash reflects
     * every component compared by {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        return 31 * super.hashCode() + Boolean.hashCode(existsHasAgg);
    }

    public boolean isExistsHasAgg() {
        return existsHasAgg;
    }
}

View File

@ -0,0 +1,26 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions;
/**
 * Marker interface for slots that are produced by the current plan node itself rather than
 * taken from the output of its children.
 *
 * <p>CheckAfterRewrite verifies that every slot used by a node can be obtained from its
 * children's output; slots implementing this interface are generated by the node and
 * therefore need to be skipped during that check.
 */
public interface SlotNotFromChildren {
}

View File

@ -32,7 +32,7 @@ import java.util.Optional;
/**
* it is not a real column exist in table.
*/
public class VirtualSlotReference extends SlotReference {
public class VirtualSlotReference extends SlotReference implements SlotNotFromChildren {
// arguments of GroupingScalarFunction
private final List<Expression> realExpressions;

View File

@ -51,6 +51,7 @@ import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.ListQuery;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.Mod;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
@ -199,6 +200,10 @@ public abstract class ExpressionVisitor<R, C>
return visitSlot(slotReference, context);
}
/**
 * Visits a {@link MarkJoinSlotReference}. The default implementation forwards to
 * {@code visitSlotReference}, so visitors that do not override this method handle it
 * the same way as any other slot reference.
 */
public R visitMarkJoinReference(MarkJoinSlotReference markJoinSlotReference, C context) {
return visitSlotReference(markJoinSlotReference, context);
}
public R visitLiteral(Literal literal, C context) {
return visit(literal, context);
}

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.plans.algebra;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinHint.JoinHintType;
import org.apache.doris.nereids.trees.plans.JoinType;
@ -39,6 +40,8 @@ public interface Join {
JoinHint getHint();
boolean isMarkJoin();
default boolean hasJoinHint() {
return getHint() != JoinHint.NONE;
}
@ -67,4 +70,8 @@ public interface Join {
return JoinHintType.NONE;
}
}
/**
 * Returns the mark-join slot associated with the left side of this join.
 * The default implementation always returns an empty {@link Optional}.
 */
default Optional<MarkJoinSlotReference> getLeftMarkJoinSlotReference() {
return Optional.empty();
}
}

View File

@ -22,6 +22,7 @@ import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Exists;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InSubquery;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.ScalarSubquery;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
@ -51,6 +52,10 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
private final SubqueryExpr subqueryExpr;
// correlation Conjunction
private final Optional<Expression> correlationFilter;
// The slot replaced by the subquery in MarkJoin
private final Optional<MarkJoinSlotReference> markJoinSlotReference;
private final Optional<Expression> subCorrespondingConject;
/**
* Constructor.
@ -59,18 +64,23 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
Optional<LogicalProperties> logicalProperties,
List<Expression> correlationSlot,
SubqueryExpr subqueryExpr, Optional<Expression> correlationFilter,
Optional<MarkJoinSlotReference> markJoinSlotReference,
Optional<Expression> subCorrespondingConject,
LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
super(PlanType.LOGICAL_APPLY, groupExpression, logicalProperties, leftChild, rightChild);
this.correlationSlot = correlationSlot == null ? ImmutableList.of() : ImmutableList.copyOf(correlationSlot);
this.subqueryExpr = Objects.requireNonNull(subqueryExpr, "subquery can not be null");
this.correlationFilter = correlationFilter;
this.markJoinSlotReference = markJoinSlotReference;
this.subCorrespondingConject = subCorrespondingConject;
}
public LogicalApply(List<Expression> correlationSlot, SubqueryExpr subqueryExpr,
Optional<Expression> correlationFilter,
Optional<Expression> correlationFilter, Optional<MarkJoinSlotReference> markJoinSlotReference,
Optional<Expression> subCorrespondingConject,
LEFT_CHILD_TYPE input, RIGHT_CHILD_TYPE subquery) {
this(Optional.empty(), Optional.empty(), correlationSlot, subqueryExpr, correlationFilter,
input, subquery);
this(Optional.empty(), Optional.empty(), correlationSlot, subqueryExpr,
correlationFilter, markJoinSlotReference, subCorrespondingConject, input, subquery);
}
public List<Expression> getCorrelationSlot() {
@ -105,6 +115,18 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
return correlationFilter.isPresent();
}
/**
 * Returns {@code true} when this apply carries a mark-join slot reference.
 */
public boolean isMarkJoin() {
return markJoinSlotReference.isPresent();
}
/**
 * Returns the slot that replaces the subquery in a MarkJoin, or empty when this apply is not a mark join.
 */
public Optional<MarkJoinSlotReference> getMarkJoinSlotReference() {
return markJoinSlotReference;
}
/**
 * Returns the optional expression stored as {@code subCorrespondingConject}.
 * NOTE(review): presumably the conjunct in which the subquery appeared, kept so it can
 * later be turned into join conjuncts - confirm against the rule that builds this apply.
 */
public Optional<Expression> getSubCorrespondingConject() {
return subCorrespondingConject;
}
@Override
public List<Slot> computeOutput() {
return ImmutableList.<Slot>builder()
@ -116,7 +138,11 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
public String toString() {
return Utils.toSqlString("LogicalApply",
"correlationSlot", correlationSlot,
"correlationFilter", correlationFilter);
"correlationFilter", correlationFilter,
"isMarkJoin", markJoinSlotReference.isPresent(),
"MarkJoinSlotReference", markJoinSlotReference.isPresent() ? markJoinSlotReference.get() : "empty",
"scalarSubCorrespondingSlot",
subCorrespondingConject.isPresent() ? subCorrespondingConject.get() : "empty");
}
@Override
@ -130,13 +156,15 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
LogicalApply that = (LogicalApply) o;
return Objects.equals(correlationSlot, that.getCorrelationSlot())
&& Objects.equals(subqueryExpr, that.getSubqueryExpr())
&& Objects.equals(correlationFilter, that.getCorrelationFilter());
&& Objects.equals(correlationFilter, that.getCorrelationFilter())
&& Objects.equals(markJoinSlotReference, that.getMarkJoinSlotReference())
&& Objects.equals(subCorrespondingConject, that.getSubCorrespondingConject());
}
@Override
public int hashCode() {
return Objects.hash(
correlationSlot, subqueryExpr, correlationFilter);
correlationSlot, subqueryExpr, correlationFilter, markJoinSlotReference, subCorrespondingConject);
}
@Override
@ -161,18 +189,21 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
/**
 * Builds a copy of this apply over a new pair of children; every other attribute is carried over.
 *
 * @param children exactly two plans: the new left child followed by the new right child
 * @return the rebuilt apply node
 */
public LogicalBinary<Plan, Plan> withChildren(List<Plan> children) {
    Preconditions.checkArgument(children.size() == 2);
    Plan newLeft = children.get(0);
    Plan newRight = children.get(1);
    return new LogicalApply<>(correlationSlot, subqueryExpr, correlationFilter,
            markJoinSlotReference, subCorrespondingConject,
            newLeft, newRight);
}
@Override
public Plan withGroupExpression(Optional<GroupExpression> groupExpression) {
return new LogicalApply<>(groupExpression, Optional.of(getLogicalProperties()),
correlationSlot, subqueryExpr, correlationFilter, left(), right());
correlationSlot, subqueryExpr, correlationFilter,
markJoinSlotReference, subCorrespondingConject, left(), right());
}
@Override
public Plan withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
return new LogicalApply<>(Optional.empty(), logicalProperties,
correlationSlot, subqueryExpr, correlationFilter, left(), right());
correlationSlot, subqueryExpr, correlationFilter,
markJoinSlotReference, subCorrespondingConject, left(), right());
}
}

View File

@ -21,6 +21,7 @@ import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
@ -55,31 +56,47 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
private final List<Expression> hashJoinConjuncts;
private final JoinHint hint;
// When the predicate condition contains subqueries and disjunctions, the join will be marked as MarkJoin.
private final Optional<MarkJoinSlotReference> markJoinSlotReference;
// Use for top-to-down join reorder
private final JoinReorderContext joinReorderContext = new JoinReorderContext();
public LogicalJoin(JoinType joinType, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
this(joinType, ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE,
Optional.empty(), Optional.empty(), leftChild, rightChild);
Optional.empty(), Optional.empty(), Optional.empty(), leftChild, rightChild);
}
public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts, LEFT_CHILD_TYPE leftChild,
RIGHT_CHILD_TYPE rightChild) {
this(joinType, hashJoinConjuncts, ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, Optional.empty(),
Optional.empty(), leftChild, rightChild);
Optional.empty(), Optional.empty(), leftChild, rightChild);
}
public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts, List<Expression> otherJoinConjuncts,
JoinHint hint, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
this(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, Optional.empty(), Optional.empty(), leftChild,
rightChild);
this(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, Optional.empty(), Optional.empty(),
Optional.empty(), leftChild, rightChild);
}
/**
 * Creates a join that may carry a mark-join slot.
 * Delegates to the full constructor with an empty group expression and empty logical properties.
 *
 * @param joinType type of the join
 * @param hashJoinConjuncts hash join conjuncts
 * @param otherJoinConjuncts remaining join conjuncts
 * @param hint join hint
 * @param markJoinSlotReference mark-join slot; empty when the join is not a mark join
 * @param leftChild left input
 * @param rightChild right input
 */
public LogicalJoin(
JoinType joinType,
List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
JoinHint hint,
Optional<MarkJoinSlotReference> markJoinSlotReference,
LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
this(joinType, hashJoinConjuncts,
otherJoinConjuncts, hint, markJoinSlotReference,
Optional.empty(), Optional.empty(), leftChild, rightChild);
}
/**
* Just use in withXXX method.
*/
private LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts, List<Expression> otherJoinConjuncts,
JoinHint hint, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild,
JoinHint hint, Optional<MarkJoinSlotReference> markJoinSlotReference,
LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild,
JoinReorderContext joinReorderContext) {
super(PlanType.LOGICAL_JOIN, Optional.empty(), Optional.empty(), leftChild, rightChild);
this.joinType = Objects.requireNonNull(joinType, "joinType can not be null");
@ -87,6 +104,7 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
this.otherJoinConjuncts = ImmutableList.copyOf(otherJoinConjuncts);
this.hint = Objects.requireNonNull(hint, "hint can not be null");
this.joinReorderContext.copyFrom(joinReorderContext);
this.markJoinSlotReference = markJoinSlotReference;
}
private LogicalJoin(
@ -94,6 +112,7 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
JoinHint hint,
Optional<MarkJoinSlotReference> markJoinSlotReference,
Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties,
LEFT_CHILD_TYPE leftChild,
@ -103,6 +122,7 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
this.hashJoinConjuncts = ImmutableList.copyOf(hashJoinConjuncts);
this.otherJoinConjuncts = ImmutableList.copyOf(otherJoinConjuncts);
this.hint = Objects.requireNonNull(hint, "hint can not be null");
this.markJoinSlotReference = markJoinSlotReference;
}
public List<Expression> getOtherJoinConjuncts() {
@ -131,6 +151,10 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
return hint;
}
/**
 * Returns {@code true} when this join carries a mark-join slot reference.
 */
public boolean isMarkJoin() {
return markJoinSlotReference.isPresent();
}
public JoinReorderContext getJoinReorderContext() {
return joinReorderContext;
}
@ -140,10 +164,16 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
return JoinUtils.getJoinOutput(joinType, left(), right());
}
/**
 * Computes the output slots that are not visible to the user: a single-element list
 * holding the mark-join slot when one is present, otherwise an empty list.
 */
@Override
public List<Slot> computeNonUserVisibleOutput() {
    if (markJoinSlotReference.isPresent()) {
        return ImmutableList.of(markJoinSlotReference.get());
    }
    return ImmutableList.of();
}
@Override
public String toString() {
List<Object> args = Lists.newArrayList(
"type", joinType,
"markJoinSlotReference", markJoinSlotReference,
"hashJoinConjuncts", hashJoinConjuncts,
"otherJoinConjuncts", otherJoinConjuncts);
if (hint != JoinHint.NONE) {
@ -168,12 +198,13 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
return joinType == that.joinType
&& hint == that.hint
&& hashJoinConjuncts.equals(that.hashJoinConjuncts)
&& otherJoinConjuncts.equals(that.otherJoinConjuncts);
&& otherJoinConjuncts.equals(that.otherJoinConjuncts)
&& Objects.equals(markJoinSlotReference, that.markJoinSlotReference);
}
@Override
public int hashCode() {
return Objects.hash(joinType, hashJoinConjuncts, otherJoinConjuncts);
return Objects.hash(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinSlotReference);
}
@Override
@ -189,6 +220,10 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
.build();
}
/**
 * Returns the mark-join slot of this join, or empty when the join is not a mark join.
 */
public Optional<MarkJoinSlotReference> getMarkJoinSlotReference() {
return markJoinSlotReference;
}
@Override
public LEFT_CHILD_TYPE left() {
return (LEFT_CHILD_TYPE) child(0);
@ -202,13 +237,15 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
@Override
public LogicalJoin<Plan, Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 2);
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, children.get(0),
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint,
markJoinSlotReference, children.get(0),
children.get(1), joinReorderContext);
}
@Override
public LogicalJoin<Plan, Plan> withGroupExpression(Optional<GroupExpression> groupExpression) {
LogicalJoin<Plan, Plan> newJoin = new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint,
markJoinSlotReference,
groupExpression, Optional.of(getLogicalProperties()), left(), right());
newJoin.getJoinReorderContext().copyFrom(this.getJoinReorderContext());
return newJoin;
@ -217,42 +254,46 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
@Override
public LogicalJoin<Plan, Plan> withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
LogicalJoin<Plan, Plan> newJoin = new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint,
markJoinSlotReference,
Optional.empty(), logicalProperties, left(), right());
newJoin.getJoinReorderContext().copyFrom(this.getJoinReorderContext());
return newJoin;
}
public LogicalJoin<Plan, Plan> withHashJoinConjuncts(List<Expression> hashJoinConjuncts) {
return new LogicalJoin<>(joinType, hashJoinConjuncts, this.otherJoinConjuncts, hint, left(), right(),
joinReorderContext);
return new LogicalJoin<>(joinType, hashJoinConjuncts, this.otherJoinConjuncts, hint, markJoinSlotReference,
left(), right(), joinReorderContext);
}
public LogicalJoin<Plan, Plan> withJoinConjuncts(
List<Expression> hashJoinConjuncts, List<Expression> otherJoinConjuncts) {
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts,
hint, left(), right(), joinReorderContext);
hint, markJoinSlotReference, left(), right(), joinReorderContext);
}
public LogicalJoin<Plan, Plan> withHashJoinConjunctsAndChildren(
List<Expression> hashJoinConjuncts, List<Plan> children) {
Preconditions.checkArgument(children.size() == 2);
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, children.get(0),
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint,
markJoinSlotReference, children.get(0),
children.get(1), joinReorderContext);
}
public LogicalJoin<Plan, Plan> withConjunctsChildren(List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts, Plan left, Plan right) {
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, left,
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, markJoinSlotReference, left,
right, joinReorderContext);
}
public LogicalJoin<Plan, Plan> withJoinType(JoinType joinType) {
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, left(), right(),
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint,
markJoinSlotReference, left(), right(),
joinReorderContext);
}
public LogicalJoin<Plan, Plan> withOtherJoinConjuncts(List<Expression> otherJoinConjuncts) {
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, left(), right(),
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint,
markJoinSlotReference, left(), right(),
joinReorderContext);
}
}

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.plans.logical;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
@ -46,23 +47,28 @@ public class UsingJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends Pl
private final ImmutableList<Expression> otherJoinConjuncts;
private final ImmutableList<Expression> hashJoinConjuncts;
private final JoinHint hint;
private final Optional<MarkJoinSlotReference> markJoinSlotReference;
public UsingJoin(JoinType joinType, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild,
List<Expression> expressions, List<Expression> hashJoinConjuncts,
JoinHint hint) {
this(joinType, leftChild, rightChild, expressions,
hashJoinConjuncts, Optional.empty(), Optional.empty(), hint);
hashJoinConjuncts, Optional.empty(), Optional.empty(), hint, Optional.empty());
}
/**
* Constructor.
*/
public UsingJoin(JoinType joinType, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild,
List<Expression> expressions, List<Expression> hashJoinConjuncts, Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties,
JoinHint hint) {
JoinHint hint, Optional<MarkJoinSlotReference> markJoinSlotReference) {
super(PlanType.LOGICAL_USING_JOIN, groupExpression, logicalProperties, leftChild, rightChild);
this.joinType = joinType;
this.otherJoinConjuncts = ImmutableList.copyOf(expressions);
this.hashJoinConjuncts = ImmutableList.copyOf(hashJoinConjuncts);
this.hint = hint;
this.markJoinSlotReference = markJoinSlotReference;
}
@Override
@ -107,19 +113,19 @@ public class UsingJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends Pl
@Override
public Plan withGroupExpression(Optional<GroupExpression> groupExpression) {
return new UsingJoin(joinType, child(0), child(1), otherJoinConjuncts,
hashJoinConjuncts, groupExpression, Optional.of(getLogicalProperties()), hint);
hashJoinConjuncts, groupExpression, Optional.of(getLogicalProperties()), hint, markJoinSlotReference);
}
@Override
public Plan withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
return new UsingJoin(joinType, child(0), child(1), otherJoinConjuncts,
hashJoinConjuncts, groupExpression, logicalProperties, hint);
hashJoinConjuncts, groupExpression, logicalProperties, hint, markJoinSlotReference);
}
@Override
public Plan withChildren(List<Plan> children) {
return new UsingJoin(joinType, children.get(0), children.get(1), otherJoinConjuncts,
hashJoinConjuncts, groupExpression, Optional.of(getLogicalProperties()), hint);
hashJoinConjuncts, groupExpression, Optional.of(getLogicalProperties()), hint, markJoinSlotReference);
}
@Override
@ -151,6 +157,14 @@ public class UsingJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends Pl
return hint;
}
/**
 * Tells whether this join was built with a mark-join slot.
 *
 * @return {@code true} when {@code markJoinSlotReference} is present, {@code false} when it is empty
 */
public boolean isMarkJoin() {
return markJoinSlotReference.isPresent();
}
/**
 * Returns the mark-join slot handed to the constructor.
 *
 * @return the stored {@code Optional}; empty when the join was created without a mark-join slot
 */
public Optional<MarkJoinSlotReference> getMarkJoinSlotReference() {
return markJoinSlotReference;
}
@Override
public Optional<Expression> getOnClauseCondition() {
return ExpressionUtils.optionalAnd(hashJoinConjuncts, otherJoinConjuncts);

View File

@ -21,6 +21,7 @@ import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
@ -50,6 +51,7 @@ public abstract class AbstractPhysicalJoin<
protected final List<Expression> hashJoinConjuncts;
protected final List<Expression> otherJoinConjuncts;
protected final JoinHint hint;
protected final Optional<MarkJoinSlotReference> markJoinSlotReference;
// use for translate only
protected final List<Expression> filterConjuncts = Lists.newArrayList();
@ -64,6 +66,7 @@ public abstract class AbstractPhysicalJoin<
List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
JoinHint hint,
Optional<MarkJoinSlotReference> markJoinSlotReference,
Optional<GroupExpression> groupExpression,
LogicalProperties logicalProperties, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
super(type, groupExpression, logicalProperties, leftChild, rightChild);
@ -71,6 +74,7 @@ public abstract class AbstractPhysicalJoin<
this.hashJoinConjuncts = ImmutableList.copyOf(hashJoinConjuncts);
this.otherJoinConjuncts = ImmutableList.copyOf(otherJoinConjuncts);
this.hint = Objects.requireNonNull(hint, "hint can not be null");
this.markJoinSlotReference = markJoinSlotReference;
}
/**
@ -82,6 +86,7 @@ public abstract class AbstractPhysicalJoin<
List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
JoinHint hint,
Optional<MarkJoinSlotReference> markJoinSlotReference,
Optional<GroupExpression> groupExpression,
LogicalProperties logicalProperties,
PhysicalProperties physicalProperties,
@ -93,6 +98,7 @@ public abstract class AbstractPhysicalJoin<
this.hashJoinConjuncts = ImmutableList.copyOf(hashJoinConjuncts);
this.otherJoinConjuncts = ImmutableList.copyOf(otherJoinConjuncts);
this.hint = hint;
this.markJoinSlotReference = markJoinSlotReference;
}
public List<Expression> getHashJoinConjuncts() {
@ -115,6 +121,10 @@ public abstract class AbstractPhysicalJoin<
return otherJoinConjuncts;
}
/**
 * Tells whether this physical join was built with a mark-join slot.
 *
 * @return {@code true} when {@code markJoinSlotReference} is present, {@code false} when it is empty
 */
public boolean isMarkJoin() {
return markJoinSlotReference.isPresent();
}
@Override
public List<? extends Expression> getExpressions() {
return new Builder<Expression>()
@ -139,12 +149,13 @@ public abstract class AbstractPhysicalJoin<
return joinType == that.joinType
&& hashJoinConjuncts.equals(that.hashJoinConjuncts)
&& otherJoinConjuncts.equals(that.otherJoinConjuncts)
&& hint.equals(that.hint);
&& hint.equals(that.hint)
&& Objects.equals(markJoinSlotReference, that.markJoinSlotReference);
}
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), joinType, hashJoinConjuncts, otherJoinConjuncts);
return Objects.hash(super.hashCode(), joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinSlotReference);
}
/**
@ -165,6 +176,10 @@ public abstract class AbstractPhysicalJoin<
return filterConjuncts;
}
/**
 * Returns the mark-join slot handed to the constructor.
 *
 * @return the stored {@code Optional}; empty when the join was created without a mark-join slot
 */
public Optional<MarkJoinSlotReference> getMarkJoinSlotReference() {
return markJoinSlotReference;
}
/**
 * Appends every given expression to this join's mutable {@code filterConjuncts} list
 * (the field is marked "use for translate only").
 *
 * @param conjuncts expressions to append; passed straight to {@code List.addAll}
 */
public void addFilterConjuncts(Collection<Expression> conjuncts) {
filterConjuncts.addAll(conjuncts);
}

View File

@ -21,6 +21,7 @@ import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
@ -48,11 +49,12 @@ public class PhysicalHashJoin<
List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
JoinHint hint,
Optional<MarkJoinSlotReference> markJoinSlotReference,
LogicalProperties logicalProperties,
LEFT_CHILD_TYPE leftChild,
RIGHT_CHILD_TYPE rightChild) {
this(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, Optional.empty(), logicalProperties, leftChild,
rightChild);
this(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, markJoinSlotReference,
Optional.empty(), logicalProperties, leftChild, rightChild);
}
/**
@ -66,10 +68,11 @@ public class PhysicalHashJoin<
List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
JoinHint hint,
Optional<MarkJoinSlotReference> markJoinSlotReference,
Optional<GroupExpression> groupExpression,
LogicalProperties logicalProperties,
LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
super(PlanType.PHYSICAL_HASH_JOIN, joinType, hashJoinConjuncts, otherJoinConjuncts, hint,
super(PlanType.PHYSICAL_HASH_JOIN, joinType, hashJoinConjuncts, otherJoinConjuncts, hint, markJoinSlotReference,
groupExpression, logicalProperties, leftChild, rightChild);
}
@ -84,13 +87,14 @@ public class PhysicalHashJoin<
List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
JoinHint hint,
Optional<MarkJoinSlotReference> markJoinSlotReference,
Optional<GroupExpression> groupExpression,
LogicalProperties logicalProperties,
PhysicalProperties physicalProperties,
StatsDeriveResult statsDeriveResult,
LEFT_CHILD_TYPE leftChild,
RIGHT_CHILD_TYPE rightChild) {
super(PlanType.PHYSICAL_HASH_JOIN, joinType, hashJoinConjuncts, otherJoinConjuncts, hint,
super(PlanType.PHYSICAL_HASH_JOIN, joinType, hashJoinConjuncts, otherJoinConjuncts, hint, markJoinSlotReference,
groupExpression, logicalProperties, physicalProperties, statsDeriveResult, leftChild, rightChild);
}
@ -104,6 +108,8 @@ public class PhysicalHashJoin<
List<Object> args = Lists.newArrayList("type", joinType,
"hashJoinCondition", hashJoinConjuncts,
"otherJoinCondition", otherJoinConjuncts,
"isMarkJoin", markJoinSlotReference.isPresent(),
"MarkJoinSlotReference", markJoinSlotReference.isPresent() ? markJoinSlotReference.get() : "empty",
"stats", statsDeriveResult);
if (hint != JoinHint.NONE) {
args.add("hint");
@ -115,28 +121,28 @@ public class PhysicalHashJoin<
@Override
public PhysicalHashJoin<Plan, Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 2);
return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint,
return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, markJoinSlotReference,
getLogicalProperties(), children.get(0), children.get(1));
}
@Override
public PhysicalHashJoin<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> withGroupExpression(
Optional<GroupExpression> groupExpression) {
return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint,
return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, markJoinSlotReference,
groupExpression, getLogicalProperties(), left(), right());
}
@Override
public PhysicalHashJoin<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> withLogicalProperties(
Optional<LogicalProperties> logicalProperties) {
return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint,
return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, markJoinSlotReference,
Optional.empty(), logicalProperties.get(), left(), right());
}
@Override
public PhysicalHashJoin<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> withPhysicalPropertiesAndStats(
PhysicalProperties physicalProperties, StatsDeriveResult statsDeriveResult) {
return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint,
return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, markJoinSlotReference,
Optional.empty(), getLogicalProperties(), physicalProperties, statsDeriveResult, left(), right());
}
}

View File

@ -21,6 +21,7 @@ import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
@ -46,11 +47,12 @@ public class PhysicalNestedLoopJoin<
JoinType joinType,
List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
Optional<MarkJoinSlotReference> markJoinSlotReference,
LogicalProperties logicalProperties,
LEFT_CHILD_TYPE leftChild,
RIGHT_CHILD_TYPE rightChild) {
this(joinType, hashJoinConjuncts, otherJoinConjuncts, Optional.empty(), logicalProperties, leftChild,
rightChild);
this(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinSlotReference,
Optional.empty(), logicalProperties, leftChild, rightChild);
}
/**
@ -63,12 +65,13 @@ public class PhysicalNestedLoopJoin<
JoinType joinType,
List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
Optional<MarkJoinSlotReference> markJoinSlotReference,
Optional<GroupExpression> groupExpression,
LogicalProperties logicalProperties,
LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
super(PlanType.PHYSICAL_NESTED_LOOP_JOIN, joinType, hashJoinConjuncts, otherJoinConjuncts,
// nested loop join ignores join hints.
JoinHint.NONE,
JoinHint.NONE, markJoinSlotReference,
groupExpression, logicalProperties, leftChild, rightChild);
}
@ -82,6 +85,7 @@ public class PhysicalNestedLoopJoin<
JoinType joinType,
List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjuncts,
Optional<MarkJoinSlotReference> markJoinSlotReference,
Optional<GroupExpression> groupExpression,
LogicalProperties logicalProperties,
PhysicalProperties physicalProperties,
@ -90,7 +94,7 @@ public class PhysicalNestedLoopJoin<
RIGHT_CHILD_TYPE rightChild) {
super(PlanType.PHYSICAL_NESTED_LOOP_JOIN, joinType, hashJoinConjuncts, otherJoinConjuncts,
// nested loop join ignores join hints.
JoinHint.NONE,
JoinHint.NONE, markJoinSlotReference,
groupExpression, logicalProperties, physicalProperties, statsDeriveResult, leftChild, rightChild);
}
@ -104,7 +108,9 @@ public class PhysicalNestedLoopJoin<
// TODO: Maybe we could pull up this to the abstract class in the future.
return Utils.toSqlString("PhysicalNestedLoopJoin",
"type", joinType,
"otherJoinCondition", otherJoinConjuncts
"otherJoinCondition", otherJoinConjuncts,
"isMarkJoin", markJoinSlotReference.isPresent(),
"markJoinSlotReference", markJoinSlotReference.isPresent() ? markJoinSlotReference.get() : "empty"
);
}
@ -112,21 +118,23 @@ public class PhysicalNestedLoopJoin<
public PhysicalNestedLoopJoin<Plan, Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 2);
return new PhysicalNestedLoopJoin<>(joinType,
hashJoinConjuncts, otherJoinConjuncts, getLogicalProperties(), children.get(0), children.get(1));
hashJoinConjuncts, otherJoinConjuncts, markJoinSlotReference,
getLogicalProperties(), children.get(0), children.get(1));
}
@Override
public PhysicalNestedLoopJoin<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> withGroupExpression(
Optional<GroupExpression> groupExpression) {
return new PhysicalNestedLoopJoin<>(joinType,
hashJoinConjuncts, otherJoinConjuncts, groupExpression, getLogicalProperties(), left(), right());
hashJoinConjuncts, otherJoinConjuncts, markJoinSlotReference,
groupExpression, getLogicalProperties(), left(), right());
}
@Override
public PhysicalNestedLoopJoin<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> withLogicalProperties(
Optional<LogicalProperties> logicalProperties) {
return new PhysicalNestedLoopJoin<>(joinType,
hashJoinConjuncts, otherJoinConjuncts, Optional.empty(),
hashJoinConjuncts, otherJoinConjuncts, markJoinSlotReference, Optional.empty(),
logicalProperties.get(), left(), right());
}
@ -134,7 +142,7 @@ public class PhysicalNestedLoopJoin<
public PhysicalNestedLoopJoin<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> withPhysicalPropertiesAndStats(
PhysicalProperties physicalProperties, StatsDeriveResult statsDeriveResult) {
return new PhysicalNestedLoopJoin<>(joinType,
hashJoinConjuncts, otherJoinConjuncts, Optional.empty(),
hashJoinConjuncts, otherJoinConjuncts, markJoinSlotReference, Optional.empty(),
getLogicalProperties(), physicalProperties, statsDeriveResult, left(), right());
}
}

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.util;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.SlotReference;
@ -208,10 +209,14 @@ public class Utils {
if (binaryExpression.left().anyMatch(correlatedSlots::contains)) {
if (binaryExpression.right() instanceof SlotReference) {
slots.add(binaryExpression.right());
} else if (binaryExpression.right() instanceof Cast) {
slots.add(((Cast) binaryExpression.right()).child());
}
} else {
if (binaryExpression.left() instanceof SlotReference) {
slots.add(binaryExpression.left());
} else if (binaryExpression.left() instanceof Cast) {
slots.add(((Cast) binaryExpression.left()).child());
}
}
return slots;

View File

@ -127,8 +127,8 @@ public class HashJoinNode extends JoinNodeBase {
*/
public HashJoinNode(PlanNodeId id, PlanNode outer, PlanNode inner, JoinOperator joinOp,
List<Expr> eqJoinConjuncts, List<Expr> otherJoinConjuncts, List<Expr> srcToOutputList,
TupleDescriptor intermediateTuple, TupleDescriptor outputTuple) {
super(id, "HASH JOIN", StatisticalType.HASH_JOIN_NODE, joinOp);
TupleDescriptor intermediateTuple, TupleDescriptor outputTuple, boolean isMarkJoin) {
super(id, "HASH JOIN", StatisticalType.HASH_JOIN_NODE, joinOp, isMarkJoin);
Preconditions.checkArgument(eqJoinConjuncts != null && !eqJoinConjuncts.isEmpty());
Preconditions.checkArgument(otherJoinConjuncts != null);
tblRefIds.addAll(outer.getTblRefIds());
@ -758,7 +758,6 @@ public class HashJoinNode extends JoinNodeBase {
StringBuilder output =
new StringBuilder().append(detailPrefix).append("join op: ").append(joinOp.toString()).append("(")
.append(distrModeStr).append(")").append("[").append(colocateReason).append("]\n");
output.append(detailPrefix).append("is mark: ").append(isMarkJoin()).append("\n");
if (detailLevel == TExplainLevel.BRIEF) {
output.append(detailPrefix).append(
String.format("cardinality=%,d", cardinality)).append("\n");
@ -809,6 +808,9 @@ public class HashJoinNode extends JoinNodeBase {
}
output.append("\n");
}
if (detailLevel == TExplainLevel.VERBOSE) {
output.append(detailPrefix).append("isMarkJoin: ").append(isMarkJoin()).append("\n");
}
return output.toString();
}

View File

@ -58,6 +58,7 @@ public abstract class JoinNodeBase extends PlanNode {
protected final TableRef innerRef;
protected final JoinOperator joinOp;
protected final boolean isMark;
protected TupleDescriptor vOutputTupleDesc;
protected ExprSubstitutionMap vSrcToOutputSMap;
protected List<TupleDescriptor> vIntermediateTupleDescList;
@ -85,10 +86,11 @@ public abstract class JoinNodeBase extends PlanNode {
} else if (joinOp.equals(JoinOperator.RIGHT_OUTER_JOIN)) {
nullableTupleIds.addAll(outer.getTupleIds());
}
this.isMark = this.innerRef != null && innerRef.isMark();
}
public boolean isMarkJoin() {
return innerRef != null && innerRef.isMark();
return isMark;
}
public JoinOperator getJoinOp() {
@ -474,10 +476,12 @@ public abstract class JoinNodeBase extends PlanNode {
/**
* Only for Nereids.
*/
public JoinNodeBase(PlanNodeId id, String planNodeName, StatisticalType statisticalType, JoinOperator joinOp) {
public JoinNodeBase(PlanNodeId id, String planNodeName,
StatisticalType statisticalType, JoinOperator joinOp, boolean isMark) {
super(id, planNodeName, statisticalType);
this.innerRef = null;
this.joinOp = joinOp;
this.isMark = isMark;
}
public TableRef getInnerRef() {

View File

@ -92,8 +92,8 @@ public class NestedLoopJoinNode extends JoinNodeBase {
*/
public NestedLoopJoinNode(PlanNodeId id, PlanNode outer, PlanNode inner, List<TupleId> tupleIds,
JoinOperator joinOperator, List<Expr> srcToOutputList, TupleDescriptor intermediateTuple,
TupleDescriptor outputTuple) {
super(id, "NESTED LOOP JOIN", StatisticalType.NESTED_LOOP_JOIN_NODE, joinOperator);
TupleDescriptor outputTuple, boolean isMarkJoin) {
super(id, "NESTED LOOP JOIN", StatisticalType.NESTED_LOOP_JOIN_NODE, joinOperator, isMarkJoin);
this.tupleIds.addAll(tupleIds);
children.add(outer);
children.add(inner);
@ -228,7 +228,6 @@ public class NestedLoopJoinNode extends JoinNodeBase {
StringBuilder output =
new StringBuilder().append(detailPrefix).append("join op: ").append(joinOp.toString()).append("(")
.append(distrModeStr).append(")\n");
output.append(detailPrefix).append("is mark: ").append(isMarkJoin()).append("\n");
if (detailLevel == TExplainLevel.BRIEF) {
output.append(detailPrefix).append(
@ -267,6 +266,9 @@ public class NestedLoopJoinNode extends JoinNodeBase {
}
output.append("\n");
}
if (detailLevel == TExplainLevel.VERBOSE) {
output.append(detailPrefix).append("isMarkJoin: ").append(isMarkJoin()).append("\n");
}
return output.toString();
}
}

View File

@ -66,6 +66,7 @@ import org.junit.jupiter.api.Test;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
@SuppressWarnings("unused")
@ -114,7 +115,7 @@ public class ChildOutputPropertyDeriverTest {
};
PhysicalHashJoin<GroupPlan, GroupPlan> join = new PhysicalHashJoin<>(JoinType.RIGHT_OUTER_JOIN,
ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, logicalProperties,
ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, Optional.empty(), logicalProperties,
groupPlan, groupPlan);
GroupExpression groupExpression = new GroupExpression(join);
@ -164,7 +165,7 @@ public class ChildOutputPropertyDeriverTest {
new SlotReference(new ExprId(0), "left", IntegerType.INSTANCE, false, Collections.emptyList()),
new SlotReference(new ExprId(2), "right", IntegerType.INSTANCE, false,
Collections.emptyList()))),
ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, logicalProperties, groupPlan, groupPlan);
ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, Optional.empty(), logicalProperties, groupPlan, groupPlan);
GroupExpression groupExpression = new GroupExpression(join);
Map<ExprId, Integer> leftMap = Maps.newHashMap();
@ -210,7 +211,7 @@ public class ChildOutputPropertyDeriverTest {
new SlotReference(new ExprId(0), "left", IntegerType.INSTANCE, false, Collections.emptyList()),
new SlotReference(new ExprId(2), "right", IntegerType.INSTANCE, false,
Collections.emptyList()))),
ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, logicalProperties, groupPlan, groupPlan);
ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, Optional.empty(), logicalProperties, groupPlan, groupPlan);
GroupExpression groupExpression = new GroupExpression(join);
Map<ExprId, Integer> leftMap = Maps.newHashMap();
@ -247,7 +248,7 @@ public class ChildOutputPropertyDeriverTest {
@Test
public void testNestedLoopJoin() {
PhysicalNestedLoopJoin<GroupPlan, GroupPlan> join = new PhysicalNestedLoopJoin<>(JoinType.CROSS_JOIN,
ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, logicalProperties, groupPlan,
ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, Optional.empty(), logicalProperties, groupPlan,
groupPlan);
GroupExpression groupExpression = new GroupExpression(join);

View File

@ -52,6 +52,7 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Optional;
@SuppressWarnings("unused")
public class RequestPropertyDeriverTest {
@ -79,7 +80,7 @@ public class RequestPropertyDeriverTest {
@Test
public void testNestedLoopJoin() {
PhysicalNestedLoopJoin<GroupPlan, GroupPlan> join = new PhysicalNestedLoopJoin<>(JoinType.CROSS_JOIN,
ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, logicalProperties, groupPlan,
ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, Optional.empty(), logicalProperties, groupPlan,
groupPlan);
GroupExpression groupExpression = new GroupExpression(join);
@ -103,7 +104,7 @@ public class RequestPropertyDeriverTest {
};
PhysicalHashJoin<GroupPlan, GroupPlan> join = new PhysicalHashJoin<>(JoinType.RIGHT_OUTER_JOIN,
ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, logicalProperties,
ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, Optional.empty(), logicalProperties,
groupPlan, groupPlan);
GroupExpression groupExpression = new GroupExpression(join);
@ -130,7 +131,7 @@ public class RequestPropertyDeriverTest {
};
PhysicalHashJoin<GroupPlan, GroupPlan> join = new PhysicalHashJoin<>(JoinType.INNER_JOIN,
ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, logicalProperties,
ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, Optional.empty(), logicalProperties,
groupPlan, groupPlan);
GroupExpression groupExpression = new GroupExpression(join);

View File

@ -38,6 +38,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.UnCorrelatedApplyProjectFi
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.NamedExpressionUtil;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
@ -219,9 +220,14 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP
logicalJoin(
any(),
logicalAggregate()
).when(FieldChecker.check("joinType", JoinType.LEFT_OUTER_JOIN))
).when(FieldChecker.check("joinType", JoinType.LEFT_SEMI_JOIN))
.when(FieldChecker.check("otherJoinConjuncts",
ImmutableList.of(new EqualTo(
new SlotReference(new ExprId(0), "k1", BigIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test", "t6")),
new SlotReference(new ExprId(7), "sum(k3)", BigIntType.INSTANCE, true,
ImmutableList.of())
), new EqualTo(
new SlotReference(new ExprId(6), "v2", BigIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test", "t7")),
new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true,
@ -473,19 +479,23 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP
.applyBottomUp(new UnCorrelatedApplyAggregateFilter())
.applyBottomUp(new ScalarApplyToJoin())
.matches(
logicalJoin(
any(),
logicalAggregate(
logicalProject()
)
)
.when(j -> j.getJoinType().equals(JoinType.LEFT_OUTER_JOIN))
.when(j -> j.getOtherJoinConjuncts().equals(ImmutableList.of(
new EqualTo(new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test", "t6")),
new SlotReference(new ExprId(6), "v2", BigIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test", "t7")))
)))
logicalJoin(
any(),
logicalAggregate(
logicalProject()
)
)
.when(j -> j.getJoinType().equals(JoinType.LEFT_SEMI_JOIN))
.when(j -> j.getOtherJoinConjuncts().equals(ImmutableList.of(
new LessThan(new SlotReference(new ExprId(0), "k1", BigIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test", "t6")),
new SlotReference(new ExprId(8), "max(aa)", BigIntType.INSTANCE, true,
ImmutableList.of())),
new EqualTo(new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test", "t6")),
new SlotReference(new ExprId(6), "v2", BigIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test", "t7")))
)))
);
}
}

View File

@ -41,6 +41,7 @@ import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
/**
* initial plan:
@ -75,7 +76,7 @@ class FindHashConditionForJoinTest {
Expression less = new LessThan(scoreId, studentId);
List<Expression> expr = ImmutableList.of(eq1, eq2, eq3, or, less);
LogicalJoin join = new LogicalJoin<>(JoinType.INNER_JOIN, new ArrayList<>(),
expr, JoinHint.NONE, student, score);
expr, JoinHint.NONE, Optional.empty(), student, score);
CascadesContext context = MemoTestUtils.createCascadesContext(join);
List<Rule> rules = Lists.newArrayList(new FindHashConditionForJoin().build());

View File

@ -43,6 +43,7 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import java.util.List;
import java.util.Optional;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class PushdownJoinOtherConditionTest {
@ -85,7 +86,7 @@ public class PushdownJoinOtherConditionTest {
right = rStudent;
}
Plan join = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, left, right);
Plan join = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, Optional.empty(), left, right);
Plan root = new LogicalProject<>(Lists.newArrayList(), join);
Memo memo = rewrite(root);
@ -125,7 +126,7 @@ public class PushdownJoinOtherConditionTest {
Expression rightSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
List<Expression> condition = ImmutableList.of(leftSide, rightSide);
Plan join = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, rStudent,
Plan join = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, Optional.empty(), rStudent,
rScore);
Plan root = new LogicalProject<>(Lists.newArrayList(), join);
@ -168,7 +169,7 @@ public class PushdownJoinOtherConditionTest {
right = rStudent;
}
Plan join = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, left, right);
Plan join = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, condition, JoinHint.NONE, Optional.empty(), left, right);
Plan root = new LogicalProject<>(Lists.newArrayList(), join);
Memo memo = rewrite(root);

View File

@ -0,0 +1,249 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.plans;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.utframe.TestWithFeService;
import org.junit.jupiter.api.Test;
public class MarkJoinTest extends TestWithFeService {
/**
 * Creates and selects database {@code test}, then creates the two fixture tables
 * {@code test_sq_dj1} and {@code test_sq_dj2} used by every query in this class.
 */
@Override
protected void runBeforeAll() throws Exception {
    createDatabase("test");
    useDatabase("test");
    // Both fixture tables share one schema; only the table name differs.
    final String[] fixtureTables = {"test_sq_dj1", "test_sq_dj2"};
    for (String tableName : fixtureTables) {
        final String ddl = "CREATE TABLE `" + tableName + "` (\n"
                + " `c1` bigint(20) NULL,\n"
                + " `c2` bigint(20) NULL,\n"
                + " `c3` bigint(20) not NULL,\n"
                + " `k4` bigint(20) not NULL,\n"
                + " `k5` bigint(20) NULL\n"
                + ") ENGINE=OLAP\n"
                + "DUPLICATE KEY(`c1`)\n"
                + "COMMENT 'OLAP'\n"
                + "DISTRIBUTED BY HASH(`c2`) BUCKETS 1\n"
                + "PROPERTIES (\n"
                + "\"replication_allocation\" = \"tag.location.default: 1\",\n"
                + "\"in_memory\" = \"false\",\n"
                + "\"storage_format\" = \"V2\",\n"
                + "\"disable_auto_compaction\" = \"false\"\n"
                + ");";
        createTable(ddl);
    }
}
// Planner checks for IN / EXISTS / scalar subqueries, mostly combined with OR / AND predicates.
/** Uncorrelated IN subquery OR'ed with a plain comparison on the outer table. */
@Test
public void test1() {
    final String sql = "SELECT * FROM test_sq_dj1 WHERE c1 IN (SELECT c1 FROM test_sq_dj2) OR c1 < 10;";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
/** Uncorrelated scalar AVG subquery compared with {@code >}, OR'ed with a plain comparison. */
@Test
public void test2() {
    final String sql = "SELECT * FROM test_sq_dj1 WHERE c1 > (SELECT AVG(c1) FROM test_sq_dj2) OR c1 < 10;";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
/** Correlated scalar AVG subquery compared with {@code >} as the only WHERE predicate. */
@Test
public void test2_1() {
    final String sql = "SELECT * FROM test_sq_dj1 WHERE c1 > (SELECT AVG(c1) FROM test_sq_dj2 where test_sq_dj1.c1 = test_sq_dj2.c1);";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
/** Correlated scalar AVG subquery compared with {@code >}, AND'ed with an equality predicate. */
@Test
public void test2_2() {
    final String sql = "SELECT * FROM test_sq_dj1 WHERE c1 > (SELECT AVG(c1) FROM test_sq_dj2 where test_sq_dj1.c1 = test_sq_dj2.c1) and c1 = 10;";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
/** Correlated scalar AVG subquery compared with {@code =}, AND'ed with an equality predicate. */
@Test
public void test2_3() {
    final String sql = "SELECT * FROM test_sq_dj1 WHERE c1 = (SELECT AVG(c1) FROM test_sq_dj2 where test_sq_dj1.c1 = test_sq_dj2.c1) and c1 = 10;";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
/** Uncorrelated EXISTS subquery OR'ed with a plain comparison on the outer table. */
@Test
public void test3() {
    final String sql = "SELECT * FROM test_sq_dj1 WHERE EXISTS (SELECT c1 FROM test_sq_dj2 WHERE c1 = 10) OR c1 < 10";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
/** Explicit LEFT SEMI JOIN in FROM, filtered by an IN subquery OR'ed with a plain comparison. */
@Test
public void test4() {
    final String sql = "SELECT * FROM test_sq_dj1 left semi join test_sq_dj2 on test_sq_dj1.c1 = test_sq_dj2.c1 WHERE c1 IN (SELECT c1 FROM test_sq_dj2) OR c1 < 10";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
/** Correlated IN subquery OR'ed with a plain comparison on the outer table. */
@Test
public void test5() {
    final String sql = "SELECT * FROM test_sq_dj1 WHERE c1 IN (SELECT c1 FROM test_sq_dj2 WHERE test_sq_dj1.c1 = test_sq_dj2.c1) OR c1 < 10";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
/*
// Not support binaryOperator children at least one is in or exists subquery
@Test
public void test6() {
PlanChecker.from(connectContext)
.checkPlannerResult("SELECT * FROM test_sq_dj1 WHERE (c1 IN (SELECT c1 FROM test_sq_dj2)) != true");
}*/
/*
// Not support binaryOperator children at least one is in or exists subquery
@Test
public void test7() {
PlanChecker.from(connectContext)
.checkPlannerResult("SELECT * FROM test_sq_dj1 WHERE (c1 IN (SELECT c1 FROM test_sq_dj2) OR c1 < 10) != true");
}*/
/** Three correlated subqueries (IN, scalar SUM with {@code <}, EXISTS) joined by OR. */
@Test
public void test8() {
    final String sql = "SELECT * FROM test_sq_dj1 WHERE c1 IN (SELECT c1 FROM test_sq_dj2 WHERE test_sq_dj1.c1 = test_sq_dj2.c1)"
            + " OR c1 < (SELECT sum(c1) FROM test_sq_dj2 WHERE test_sq_dj1.c1 = test_sq_dj2.c1)"
            + " OR exists (SELECT c1 FROM test_sq_dj2 WHERE test_sq_dj1.c1 = test_sq_dj2.c1)";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
/**
 * Correlated IN subquery OR'ed with an unparenthesized AND chain of two scalar SUM
 * comparisons and an EXISTS subquery.
 */
@Test
public void test9() {
    final String sql = "SELECT * FROM test_sq_dj1 WHERE c1 IN (SELECT c1 FROM test_sq_dj2 WHERE test_sq_dj1.c1 = test_sq_dj2.c1)"
            + " OR c1 < (SELECT sum(c1) FROM test_sq_dj2 WHERE test_sq_dj1.c1 = test_sq_dj2.c1) "
            + " AND c1 < (SELECT sum(c1) FROM test_sq_dj2 WHERE test_sq_dj1.c1 = test_sq_dj2.c1) "
            + " AND EXISTS (SELECT c1 FROM test_sq_dj2 WHERE test_sq_dj1.c1 = test_sq_dj2.c1)";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
/**
 * Parenthesized disjunction (correlated IN subquery OR scalar-subquery comparison)
 * AND-ed with a correlated EXISTS subquery.
 */
@Test
public void test10() {
    final String sql = "SELECT * FROM test_sq_dj1 WHERE (c1 IN (SELECT c1 FROM test_sq_dj2 WHERE test_sq_dj1.c1 = test_sq_dj2.c1)"
            + " OR c1 < (SELECT sum(c1) FROM test_sq_dj2 WHERE test_sq_dj1.c1 = test_sq_dj2.c1)) "
            + " AND EXISTS (SELECT c1 FROM test_sq_dj2 WHERE test_sq_dj1.c1 = test_sq_dj2.c1)";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
/**
 * Inequality against an uncorrelated scalar sum(c1) subquery, OR-ed with a plain comparison.
 */
@Test
public void test11() {
    final String sql = "SELECT * FROM test_sq_dj1 WHERE c1 != (SELECT sum(c1) FROM test_sq_dj2) OR c1 < 10;";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
/**
 * Uncorrelated scalar-subquery inequality AND-ed with an equality, the pair then OR-ed
 * with a plain comparison (no parentheses in the query text).
 */
@Test
public void test12() {
    final String sql = "SELECT * FROM test_sq_dj1 WHERE c1 != (SELECT sum(c1) FROM test_sq_dj2) and c1 = 1 OR c1 < 10;";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
/**
 * Parenthesized disjunction containing an uncorrelated scalar subquery,
 * followed by two further AND-ed equality predicates on c1.
 */
@Test
public void test13() {
    final String sql = "SELECT * FROM test_sq_dj1 WHERE (c1 != (SELECT sum(c1) FROM test_sq_dj2) and c1 = 1 OR c1 < 10) and c1 = 10 and c1 = 15;";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
/**
 * CASE expression in the select list whose WHEN condition, THEN value and ELSE value
 * are each an uncorrelated scalar subquery over test_sq_dj1.
 */
@Test
public void test14() {
    final String sql = "SELECT CASE\n"
            + " WHEN (\n"
            + " SELECT COUNT(*) / 2\n"
            + " FROM test_sq_dj1\n"
            + " ) > c1 THEN (\n"
            + " SELECT AVG(c1)\n"
            + " FROM test_sq_dj1\n"
            + " )\n"
            + " ELSE (\n"
            + " SELECT SUM(c2)\n"
            + " FROM test_sq_dj1\n"
            + " )\n"
            + " END AS kk4\n"
            + " FROM test_sq_dj1 ;";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
/**
 * Variant of the CASE-expression query where the WHEN condition is an EXISTS subquery;
 * the THEN and ELSE values remain scalar subqueries.
 */
@Test
public void test14_1() {
    final String sql = "SELECT CASE\n"
            + " WHEN exists (\n"
            + " SELECT COUNT(*) / 2\n"
            + " FROM test_sq_dj1\n"
            + " ) THEN (\n"
            + " SELECT AVG(c1)\n"
            + " FROM test_sq_dj1\n"
            + " )\n"
            + " ELSE (\n"
            + " SELECT SUM(c2)\n"
            + " FROM test_sq_dj1\n"
            + " )\n"
            + " END AS kk4\n"
            + " FROM test_sq_dj1 ;";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
/**
 * Inequality against a correlated scalar sum(c1) subquery, AND-ed with two equality predicates on c1.
 */
@Test
public void test15() {
    final String sql = "SELECT * FROM test_sq_dj1 WHERE c1 != (SELECT sum(c1) FROM test_sq_dj2 where test_sq_dj1.c1 = test_sq_dj2.c1) and c1 = 10 and c1 = 15;";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
/**
 * Disjunction of a correlated IN subquery, a scalar-subquery comparison, and an EXISTS
 * subquery whose select list is an aggregate (sum(c1)).
 */
@Test
public void test16() {
    final String sql = "SELECT * FROM test_sq_dj1 WHERE c1 IN (SELECT c1 FROM test_sq_dj2 WHERE test_sq_dj1.c1 = test_sq_dj2.c1)"
            + " OR c1 < (SELECT sum(c1) FROM test_sq_dj2 WHERE test_sq_dj1.c1 = test_sq_dj2.c1)"
            + " OR exists (SELECT sum(c1) FROM test_sq_dj2 WHERE test_sq_dj1.c1 = test_sq_dj2.c1)";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
/**
 * Conjunction of three groups: a parenthesized predicate mixing an uncorrelated scalar
 * subquery with plain comparisons, a parenthesized disjunction of correlated IN and
 * scalar subqueries, and a correlated EXISTS over an aggregate subquery.
 */
@Test
public void test17() {
    final String sql = "SELECT * FROM test_sq_dj1 WHERE ((c1 != (SELECT sum(c1) FROM test_sq_dj2) and c1 = 1 OR c1 < 10) and c1 = 10 and c1 = 15)"
            + " and (c1 IN (SELECT c1 FROM test_sq_dj2 WHERE test_sq_dj1.c1 = test_sq_dj2.c1)"
            + " OR c1 < (SELECT sum(c1) FROM test_sq_dj2 WHERE test_sq_dj1.c1 = test_sq_dj2.c1))"
            + " and exists (SELECT sum(c1) FROM test_sq_dj2 WHERE test_sq_dj1.c1 = test_sq_dj2.c1);";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
/**
 * Inequality against a scalar sum(c1) subquery correlated on c3, OR-ed with a plain comparison on c1.
 */
@Test
public void test18() {
    final String sql = "select * from test_sq_dj1 where test_sq_dj1.c1 != (select sum(c1) from test_sq_dj2 where test_sq_dj2.c3 = test_sq_dj1.c3) or c1 > 10";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
/**
 * Correlated IN subquery OR-ed with an unparenthesized conjunction of a
 * scalar-subquery comparison and a correlated EXISTS subquery.
 */
@Test
public void test19() {
    final String sql = "SELECT * FROM test_sq_dj1 WHERE c1 IN (SELECT c1 FROM test_sq_dj2 WHERE test_sq_dj1.c1 = test_sq_dj2.c1)"
            + " OR c1 < (SELECT sum(c1) FROM test_sq_dj2 WHERE test_sq_dj1.c1 = test_sq_dj2.c1)"
            + " AND exists (SELECT c1 FROM test_sq_dj2 WHERE test_sq_dj1.c1 = test_sq_dj2.c1)";
    PlanChecker.from(connectContext).checkPlannerResult(sql);
}
}

View File

@ -228,20 +228,20 @@ public class PlanEqualsTest {
Lists.newArrayList(new EqualTo(
new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList()),
new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList()))),
ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, logicalProperties, left, right);
ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, Optional.empty(), logicalProperties, left, right);
PhysicalHashJoin<Plan, Plan> expected = new PhysicalHashJoin<>(JoinType.INNER_JOIN,
Lists.newArrayList(new EqualTo(
new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList()),
new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList()))),
ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, logicalProperties, left, right);
ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, Optional.empty(), logicalProperties, left, right);
Assertions.assertEquals(expected, actual);
PhysicalHashJoin<Plan, Plan> unexpected = new PhysicalHashJoin<>(JoinType.INNER_JOIN,
Lists.newArrayList(new EqualTo(
new SlotReference(new ExprId(2), "a", BigIntType.INSTANCE, false, Lists.newArrayList()),
new SlotReference(new ExprId(3), "b", BigIntType.INSTANCE, true, Lists.newArrayList()))),
ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, logicalProperties, left, right);
ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, Optional.empty(), logicalProperties, left, right);
Assertions.assertNotEquals(unexpected, actual);
}

View File

@ -74,7 +74,7 @@ public class PlanToStringTest {
new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList()))),
left, right);
Assertions.assertTrue(plan.toString().matches(
"LogicalJoin \\( type=INNER_JOIN, hashJoinConjuncts=\\[\\(a#\\d+ = b#\\d+\\)], otherJoinConjuncts=\\[] \\)"));
"LogicalJoin \\( type=INNER_JOIN, markJoinSlotReference=Optional.empty, hashJoinConjuncts=\\[\\(a#\\d+ = b#\\d+\\)], otherJoinConjuncts=\\[] \\)"));
}
@Test

View File

@ -42,6 +42,7 @@ import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@ -113,7 +114,7 @@ public class LogicalPlanBuilder {
/**
 * Combines the builder's current plan with {@code right} in a {@code LogicalJoin} built from the
 * given join type, hash-join conjuncts and other join conjuncts, using {@code JoinHint.NONE},
 * and returns the result of {@code from(join)}.
 */
public LogicalPlanBuilder join(LogicalPlan right, JoinType joinType, List<Expression> hashJoinConjuncts,
List<Expression> otherJoinConjucts) {
LogicalJoin<LogicalPlan, LogicalPlan> join = new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjucts,
// NOTE(review): the next two lines look like the removed/added pair of a diff hunk; presumably only
// the second one (which passes Optional.empty()) belongs to the current code — confirm.
JoinHint.NONE, this.plan, right);
JoinHint.NONE, Optional.empty(), this.plan, right);
return from(join);
}