[feat](nereids) support null safe eq runtime filter (FE part) (#31655)

The BE part has been merged in #31754.
This commit is contained in:
minghong
2024-03-07 14:29:25 +08:00
committed by yiguolei
parent fa411f88df
commit db389d7d4e
10 changed files with 174 additions and 41 deletions

View File

@ -24,6 +24,7 @@ import org.apache.doris.nereids.stats.ExpressionEstimation;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
@ -268,25 +269,22 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts().stream().collect(Collectors.toList());
boolean buildSideContainsConsumer = hasCTEConsumerDescendant((PhysicalPlan) join.right());
for (int i = 0; i < hashJoinConjuncts.size(); i++) {
// BE do not support RF generated from NullSafeEqual, skip them
if (hashJoinConjuncts.get(i) instanceof EqualTo) {
EqualTo equalTo = ((EqualTo) JoinUtils.swapEqualToForChildrenOrder(
(EqualTo) hashJoinConjuncts.get(i), join.left().getOutputSet()));
for (TRuntimeFilterType type : legalTypes) {
//bitmap rf is generated by nested loop join.
if (type == TRuntimeFilterType.BITMAP) {
continue;
}
long buildSideNdv = getBuildSideNdv(join, equalTo);
Pair<PhysicalRelation, Slot> pair = ctx.getAliasTransferMap().get(equalTo.right());
// CteConsumer is not allowed to generate RF in order to avoid RF cycle.
if ((pair == null && buildSideContainsConsumer)
|| (pair != null && pair.first instanceof PhysicalCTEConsumer)) {
continue;
}
join.pushDownRuntimeFilter(context, generator, join, equalTo.right(),
equalTo.left(), type, buildSideNdv, i);
EqualPredicate equalTo = JoinUtils.swapEqualToForChildrenOrder(
(EqualPredicate) hashJoinConjuncts.get(i), join.left().getOutputSet());
for (TRuntimeFilterType type : legalTypes) {
//bitmap rf is generated by nested loop join.
if (type == TRuntimeFilterType.BITMAP) {
continue;
}
long buildSideNdv = getBuildSideNdv(join, equalTo);
Pair<PhysicalRelation, Slot> pair = ctx.getAliasTransferMap().get(equalTo.right());
// CteConsumer is not allowed to generate RF in order to avoid RF cycle.
if ((pair == null && buildSideContainsConsumer)
|| (pair != null && pair.first instanceof PhysicalCTEConsumer)) {
continue;
}
join.pushDownRuntimeFilter(context, generator, join, equalTo.right(),
equalTo.left(), type, buildSideNdv, i);
}
}
return join;

View File

@ -22,6 +22,7 @@ import org.apache.doris.nereids.rules.expression.rules.CaseWhenToIf;
import org.apache.doris.nereids.rules.expression.rules.DateFunctionRewrite;
import org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule;
import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule;
import org.apache.doris.nereids.rules.expression.rules.NullSafeEqualToEqual;
import org.apache.doris.nereids.rules.expression.rules.OrToIn;
import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
import org.apache.doris.nereids.rules.expression.rules.SimplifyDecimalV3Comparison;
@ -48,7 +49,8 @@ public class ExpressionOptimization extends ExpressionRewrite {
OrToIn.INSTANCE,
ArrayContainToArrayOverlap.INSTANCE,
CaseWhenToIf.INSTANCE,
TopnToMax.INSTANCE
TopnToMax.INSTANCE,
NullSafeEqualToEqual.INSTANCE
);
private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES);

View File

@ -0,0 +1,62 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rules;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
/**
 * Simplifies the null-safe equality operator "<=>" when the result is provably unchanged.
 * <ul>
 *   <li>convert "A <=> null" (or "null <=> A") to "A is null" if A is nullable,
 *       or to FALSE if A is not nullable</li>
 *   <li>convert "A <=> B" to "A = B" if both sides are not nullable</li>
 * </ul>
 * The conversion to "=" requires BOTH sides to be non-nullable: when only one side is
 * non-nullable and the other side evaluates to NULL, "<=>" yields FALSE while "=" yields NULL.
 * Those results differ in projections and under NOT, so the rewrite is not value-preserving
 * in that case.
 */
public class NullSafeEqualToEqual extends DefaultExpressionRewriter<ExpressionRewriteContext> implements
        ExpressionRewriteRule<ExpressionRewriteContext> {
    public static final NullSafeEqualToEqual INSTANCE = new NullSafeEqualToEqual();

    /**
     * Applies this rule to {@code expr} by dispatching it to this visitor.
     *
     * @param expr the expression to rewrite
     * @param ctx the rewrite context, forwarded to the visitor
     * @return the rewritten expression, or the visited result unchanged when nothing applies
     */
    @Override
    public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
        return expr.accept(this, ctx);
    }

    /**
     * Rewrites a single "<=>" predicate.
     * NOTE(review): the operands themselves are not visited by this override, so a "<=>"
     * nested inside an operand is presumably left for another pass - confirm.
     *
     * @param nullSafeEqual the "<=>" predicate being visited
     * @param ctx the rewrite context (not used by this method)
     * @return an IsNull, a FALSE literal, an EqualTo, or the original predicate
     */
    @Override
    public Expression visitNullSafeEqual(NullSafeEqual nullSafeEqual, ExpressionRewriteContext ctx) {
        Expression left = nullSafeEqual.left();
        Expression right = nullSafeEqual.right();
        if (left instanceof NullLiteral) {
            return compareWithNullLiteral(right);
        }
        if (right instanceof NullLiteral) {
            return compareWithNullLiteral(left);
        }
        // Only safe when neither side can be NULL; otherwise "=" could return NULL
        // where "<=>" returns FALSE.
        if (!left.nullable() && !right.nullable()) {
            return new EqualTo(left, right);
        }
        return nullSafeEqual;
    }

    /** Rewrites "operand <=> null": IS NULL for a nullable operand, FALSE otherwise. */
    private static Expression compareWithNullLiteral(Expression operand) {
        if (operand.nullable()) {
            return new IsNull(operand);
        }
        return BooleanLiteral.FALSE;
    }
}

View File

@ -63,7 +63,8 @@ public class FindHashConditionForJoin extends OneRewriteRuleFactory {
}
List<Expression> combinedHashJoinConjuncts = Streams
.concat(join.getHashJoinConjuncts().stream(), extractedHashJoinConjuncts.stream())
.concat(join.getHashJoinConjuncts().stream(),
extractedHashJoinConjuncts.stream())
.distinct()
.collect(ImmutableList.toImmutableList());
JoinType joinType = join.getJoinType();

View File

@ -24,7 +24,9 @@ import org.apache.doris.nereids.processor.post.RuntimeFilterContext;
import org.apache.doris.nereids.processor.post.RuntimeFilterGenerator;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.AbstractPlan;
import org.apache.doris.nereids.trees.plans.Explainable;
@ -131,6 +133,14 @@ public abstract class AbstractPhysicalPlan extends AbstractPlan implements Physi
ctx.setTargetsOnScanNode(ctx.getAliasTransferPair(probeSlot).first, scanSlot);
}
} else {
// null safe equal runtime filter only support bloom filter
EqualPredicate eq = (EqualPredicate) builderNode.getHashJoinConjuncts().get(exprOrder);
if (eq instanceof NullSafeEqual && type == TRuntimeFilterType.IN_OR_BLOOM) {
type = TRuntimeFilterType.BLOOM;
}
if (eq instanceof NullSafeEqual && type != TRuntimeFilterType.BLOOM) {
return false;
}
filter = new RuntimeFilter(generator.getNextId(),
src, ImmutableList.of(scanSlot), ImmutableList.of(probeExpr),
type, exprOrder, builderNode, buildSideNdv,

View File

@ -199,12 +199,6 @@ public class HashJoinNode extends JoinNodeBase {
for (Expr eqJoinPredicate : eqJoinConjuncts) {
Preconditions.checkArgument(eqJoinPredicate instanceof BinaryPredicate);
BinaryPredicate eqJoin = (BinaryPredicate) eqJoinPredicate;
if (eqJoin.getOp().equals(BinaryPredicate.Operator.EQ_FOR_NULL)) {
Preconditions.checkArgument(eqJoin.getChildren().size() == 2);
if (!eqJoin.getChild(0).isNullable() || !eqJoin.getChild(1).isNullable()) {
eqJoin.setOp(BinaryPredicate.Operator.EQ);
}
}
this.eqJoinConjuncts.add(eqJoin);
}
this.distrMode = DistributionMode.NONE;

View File

@ -230,6 +230,15 @@ public final class RuntimeFilter {
}
tFilter.setOptRemoteRf(hasRemoteTargets);
tFilter.setBloomFilterSizeCalculatedByNdv(bloomFilterSizeCalculatedByNdv);
if (builderNode instanceof HashJoinNode) {
HashJoinNode join = (HashJoinNode) builderNode;
BinaryPredicate eq = join.getEqJoinConjuncts().get(exprOrder);
if (eq.getOp().equals(BinaryPredicate.Operator.EQ_FOR_NULL)) {
tFilter.setNullAware(true);
} else {
tFilter.setNullAware(false);
}
}
return tFilter;
}

View File

@ -94,6 +94,7 @@ import org.apache.doris.thrift.TPipelineFragmentParams;
import org.apache.doris.thrift.TPipelineFragmentParamsList;
import org.apache.doris.thrift.TPipelineInstanceParams;
import org.apache.doris.thrift.TPipelineWorkloadGroup;
import org.apache.doris.thrift.TPlanFragment;
import org.apache.doris.thrift.TPlanFragmentDestination;
import org.apache.doris.thrift.TPlanFragmentExecParams;
import org.apache.doris.thrift.TQueryGlobals;
@ -3703,6 +3704,7 @@ public class Coordinator implements CoordInterface {
Map<TNetworkAddress, TPipelineFragmentParams> res = new HashMap();
Map<TNetworkAddress, Integer> instanceIdx = new HashMap();
TPlanFragment fragmentThrift = fragment.toThrift();
for (int i = 0; i < instanceExecParams.size(); ++i) {
final FInstanceExecParam instanceExecParam = instanceExecParams.get(i);
Map<Integer, List<TScanRangeParams>> scanRanges = instanceExecParam.perNodeScanRanges;
@ -3728,7 +3730,7 @@ public class Coordinator implements CoordInterface {
params.query_options.setMemLimit(memLimit);
params.setSendQueryStatisticsWithEveryBatch(
fragment.isTransferQueryStatisticsWithEveryBatch());
params.setFragment(fragment.toThrift());
params.setFragment(fragmentThrift);
params.setLocalParams(Lists.newArrayList());
if (tWorkloadGroups != null) {
params.setWorkloadGroups(tWorkloadGroups);