[feature](Nereids) support basic runtime filter (#12182)

This PR adds runtime filters to the Nereids planner. Currently only pushing through join nodes and scan nodes is supported.
TODO:
1. Currently supports inner join, cross join, and right outer join; other join types will be supported in the future.
2. translate left outer join to inner join if there are inner join ancestors.
3. some complex situation cannot be handled now, see more details in test case: testPushDownThroughJoin.
4. Support the case where the src key is an aggregate group key.
This commit is contained in:
mch_ucchi
2022-09-16 02:21:01 +08:00
committed by GitHub
parent 0daa25d9a9
commit a63cdc8a7c
22 changed files with 916 additions and 14 deletions

View File

@ -27,6 +27,7 @@ import org.apache.doris.nereids.jobs.scheduler.JobScheduler;
import org.apache.doris.nereids.jobs.scheduler.JobStack;
import org.apache.doris.nereids.jobs.scheduler.SimpleJobScheduler;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.processor.post.RuntimeFilterContext;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleFactory;
@ -56,6 +57,8 @@ public class CascadesContext {
// subqueryExprIsAnalyzed: whether the subquery has been analyzed.
private Map<SubqueryExpr, Boolean> subqueryExprIsAnalyzed;
private RuntimeFilterContext runtimeFilterContext;
/**
* Constructor of OptimizerContext.
*
@ -70,6 +73,7 @@ public class CascadesContext {
this.jobScheduler = new SimpleJobScheduler();
this.currentJobContext = new JobContext(this, PhysicalProperties.ANY, Double.MAX_VALUE);
this.subqueryExprIsAnalyzed = new HashMap<>();
this.runtimeFilterContext = new RuntimeFilterContext(getConnectContext().getSessionVariable());
}
public static CascadesContext newContext(StatementContext statementContext, Plan initPlan) {
@ -124,6 +128,10 @@ public class CascadesContext {
return currentJobContext;
}
public RuntimeFilterContext getRuntimeFilterContext() {
return runtimeFilterContext;
}
public void setCurrentJobContext(JobContext currentJobContext) {
this.currentJobContext = currentJobContext;
}

View File

@ -24,8 +24,8 @@ import org.apache.doris.common.UserException;
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator;
import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
import org.apache.doris.nereids.jobs.batch.NereidsRewriteJobExecutor;
import org.apache.doris.nereids.jobs.batch.OptimizeRulesJob;
import org.apache.doris.nereids.jobs.batch.RewriteJob;
import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
@ -38,6 +38,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.planner.PlanFragment;
import org.apache.doris.planner.Planner;
import org.apache.doris.planner.RuntimeFilter;
import org.apache.doris.planner.ScanNode;
import com.google.common.annotations.VisibleForTesting;
@ -71,7 +72,7 @@ public class NereidsPlanner extends Planner {
PhysicalPlan physicalPlan = plan(logicalPlanAdapter.getLogicalPlan(), PhysicalProperties.ANY);
PhysicalPlanTranslator physicalPlanTranslator = new PhysicalPlanTranslator();
PlanTranslatorContext planTranslatorContext = new PlanTranslatorContext();
PlanTranslatorContext planTranslatorContext = new PlanTranslatorContext(cascadesContext);
PlanFragment root = physicalPlanTranslator.translatePlan(physicalPlan, planTranslatorContext);
scanNodeList = planTranslatorContext.getScanNodes();
@ -146,7 +147,7 @@ public class NereidsPlanner extends Planner {
* Logical plan rewrite based on a series of heuristic rules.
*/
private void rewrite() {
new RewriteJob(cascadesContext).execute();
new NereidsRewriteJobExecutor(cascadesContext).execute();
}
private void deriveStats() {
@ -210,4 +211,14 @@ public class NereidsPlanner extends Planner {
public void appendTupleInfo(StringBuilder str) {
str.append(descTable.getExplainString());
}
@Override
public List<RuntimeFilter> getRuntimeFilters() {
return cascadesContext.getRuntimeFilterContext().getLegacyFilters();
}
@VisibleForTesting
public CascadesContext getCascadesContext() {
return cascadesContext;
}
}

View File

@ -331,6 +331,12 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
Utils.execWithUncheckedException(olapScanNode::init);
context.addScanNode(olapScanNode);
// translate runtime filter
context.getRuntimeTranslator().ifPresent(
runtimeFilterGenerator -> runtimeFilterGenerator.getTargetOnScanNode(olapScan.getId()).forEach(
expr -> runtimeFilterGenerator.translateRuntimeFilterTarget(expr, olapScanNode, context)
)
);
// Create PlanFragment
DataPartition dataPartition = DataPartition.RANDOM;
if (olapScan.getDistributionSpec() instanceof DistributionSpecHash) {
@ -522,6 +528,10 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
hashJoinNode.setvIntermediateTupleDescList(Lists.newArrayList(outputDescriptor));
hashJoinNode.setvOutputTupleDesc(outputDescriptor);
hashJoinNode.setvSrcToOutputSMap(srcToOutput);
// translate runtime filter
context.getRuntimeTranslator().ifPresent(runtimeFilterTranslator -> runtimeFilterTranslator
.getRuntimeFilterOfHashJoinNode(physicalHashJoin)
.forEach(filter -> runtimeFilterTranslator.createLegacyRuntimeFilter(filter, hashJoinNode, context)));
return currentFragment;
}

View File

@ -25,6 +25,7 @@ import org.apache.doris.analysis.TupleDescriptor;
import org.apache.doris.analysis.TupleId;
import org.apache.doris.catalog.Column;
import org.apache.doris.common.IdGenerator;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.planner.PlanFragment;
@ -33,6 +34,7 @@ import org.apache.doris.planner.PlanNode;
import org.apache.doris.planner.PlanNodeId;
import org.apache.doris.planner.ScanNode;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
@ -49,6 +51,8 @@ public class PlanTranslatorContext {
private final DescriptorTable descTable = new DescriptorTable();
private final RuntimeFilterTranslator translator;
/**
* index from Nereids' slot to legacy slot.
*/
@ -65,6 +69,15 @@ public class PlanTranslatorContext {
private final IdGenerator<PlanNodeId> nodeIdGenerator = PlanNodeId.createGenerator();
public PlanTranslatorContext(CascadesContext ctx) {
this.translator = new RuntimeFilterTranslator(ctx.getRuntimeFilterContext());
}
@VisibleForTesting
public PlanTranslatorContext() {
translator = null;
}
public List<PlanFragment> getPlanFragments() {
return planFragments;
}
@ -73,6 +86,10 @@ public class PlanTranslatorContext {
return descTable.createTupleDescriptor();
}
public Optional<RuntimeFilterTranslator> getRuntimeTranslator() {
return Optional.ofNullable(translator);
}
public PlanFragmentId nextFragmentId() {
return fragmentIdGenerator.getNextId();
}

View File

@ -0,0 +1,99 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.glue.translator;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.nereids.processor.post.RuntimeFilterContext;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.RuntimeFilter;
import org.apache.doris.planner.HashJoinNode;
import org.apache.doris.planner.HashJoinNode.DistributionMode;
import org.apache.doris.planner.OlapScanNode;
import org.apache.doris.planner.RuntimeFilter.RuntimeFilterTarget;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Collections;
import java.util.List;
/**
 * Translates Nereids runtime filters into the legacy planner's runtime-filter
 * representation during plan translation.
 *
 * <p>All state lives in the shared {@link RuntimeFilterContext}; this class only reads
 * from and writes into that context while the physical plan is being translated.
 */
public class RuntimeFilterTranslator {

    private final RuntimeFilterContext context;

    public RuntimeFilterTranslator(RuntimeFilterContext context) {
        this.context = context;
        // Eagerly build the (hash join node -> runtime filters) index so that
        // getRuntimeFilterOfHashJoinNode can answer lookups during translation.
        context.generatePhysicalHashJoinToRuntimeFilter();
    }

    /**
     * Returns the runtime filters built by the given hash join node (empty list if none).
     */
    public List<RuntimeFilter> getRuntimeFilterOfHashJoinNode(PhysicalHashJoin join) {
        return context.getRuntimeFilterOnHashJoinNode(join);
    }

    /**
     * Returns the target slots registered on the olap scan node with the given relation id
     * (empty list if the scan carries no runtime filter targets).
     */
    public List<Slot> getTargetOnScanNode(RelationId id) {
        return context.getTargetOnOlapScanNodeMap().getOrDefault(id, Collections.emptyList());
    }

    /**
     * translate runtime filter target.
     *
     * <p>Records the legacy {@code SlotRef} for the target slot (slot refs are re-created
     * during translation, so the mapping must be captured here) and remembers which olap
     * scan node hosts the target.
     *
     * @param slot target slot on the scan side
     * @param node olap scan node
     * @param ctx plan translator context
     */
    public void translateRuntimeFilterTarget(Slot slot, OlapScanNode node, PlanTranslatorContext ctx) {
        context.setKVInNormalMap(context.getExprIdToOlapScanNodeSlotRef(),
                slot.getExprId(), ctx.findSlotRef(slot.getExprId()));
        context.setKVInNormalMap(context.getScanNodeOfLegacyRuntimeFilterTarget(), slot, node);
    }

    /**
     * generate legacy runtime filter.
     *
     * <p>Builds an {@code org.apache.doris.planner.RuntimeFilter} from the Nereids filter,
     * attaches its scan-node target, finalizes it, and appends it to the context's legacy
     * filter list. The mutation order below (broadcast flag, target, then finalize) mirrors
     * the legacy planner's expectations — do not reorder.
     *
     * @param filter nereids runtime filter
     * @param node hash join node that builds the filter
     * @param ctx plan translator context
     */
    public void createLegacyRuntimeFilter(RuntimeFilter filter, HashJoinNode node, PlanTranslatorContext ctx) {
        SlotRef src = ctx.findSlotRef(filter.getSrcExpr().getExprId());
        // target slot ref was captured earlier by translateRuntimeFilterTarget
        SlotRef target = context.getExprIdToOlapScanNodeSlotRef().get(filter.getTargetExpr().getExprId());
        org.apache.doris.planner.RuntimeFilter origFilter
                = org.apache.doris.planner.RuntimeFilter.fromNereidsRuntimeFilter(
                filter.getId(), node, src, filter.getExprOrder(), target,
                ImmutableMap.of(target.getDesc().getParent().getId(), ImmutableList.of(target.getSlotId())),
                filter.getType(), context.getLimits());
        origFilter.setIsBroadcast(node.getDistributionMode() == DistributionMode.BROADCAST);
        // mark the Nereids-side filter as consumed
        filter.setFinalized();
        OlapScanNode scanNode = context.getScanNodeOfLegacyRuntimeFilterTarget().get(filter.getTargetExpr());
        origFilter.addTarget(new RuntimeFilterTarget(
                scanNode,
                target,
                true,
                scanNode.getFragmentId().equals(node.getFragmentId())));
        context.getLegacyFilters().add(finalize(origFilter));
    }

    // Finalization sequence required by the legacy runtime filter before it can be used;
    // the three calls below are order-dependent.
    private org.apache.doris.planner.RuntimeFilter finalize(org.apache.doris.planner.RuntimeFilter origFilter) {
        origFilter.markFinalized();
        origFilter.assignToPlanNodes();
        origFilter.extractTargetsPosition();
        return origFilter;
    }
}

View File

@ -37,14 +37,14 @@ import com.google.common.collect.ImmutableList;
/**
* Apply rules to normalize expressions.
*/
public class RewriteJob extends BatchRulesJob {
public class NereidsRewriteJobExecutor extends BatchRulesJob {
/**
* Constructor.
*
* @param cascadesContext context for applying rules.
*/
public RewriteJob(CascadesContext cascadesContext) {
public NereidsRewriteJobExecutor(CascadesContext cascadesContext) {
super(cascadesContext);
ImmutableList<Job> jobs = new ImmutableList.Builder<Job>()
/*

View File

@ -21,6 +21,7 @@ import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import java.util.List;
import java.util.Objects;
@ -43,8 +44,15 @@ public class PlanPostProcessors {
return resultPlan;
}
/**
* get processors
*/
public List<PlanPostProcessor> getProcessors() {
// add processor if we need
return ImmutableList.of();
Builder<PlanPostProcessor> builder = ImmutableList.builder();
if (cascadesContext.getConnectContext().getSessionVariable().isEnableNereidsRuntimeFilter()) {
builder.add(new RuntimeFilterGenerator());
}
return builder.build();
}
}

View File

@ -0,0 +1,181 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.processor.post;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.common.IdGenerator;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.RuntimeFilter;
import org.apache.doris.planner.OlapScanNode;
import org.apache.doris.planner.RuntimeFilterGenerator.FilterSizeLimits;
import org.apache.doris.planner.RuntimeFilterId;
import org.apache.doris.qe.SessionVariable;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.jetbrains.annotations.NotNull;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
 * Runtime filter context used at post process and translation.
 *
 * <p>The {@code RuntimeFilterGenerator} post processor registers the Nereids runtime filters
 * it creates here (keyed by the expr id of each filter's target slot), and the plan translator
 * later reads this state to build the legacy planner's runtime filters.
 */
public class RuntimeFilterContext {

    // NOTE(review): appears unused in this class — filter ids are generated elsewhere;
    // confirm before removing.
    private final IdGenerator<RuntimeFilterId> generator = RuntimeFilterId.createGenerator();

    // exprId of target to runtime filter.
    private final Map<ExprId, List<RuntimeFilter>> targetExprIdToFilter = Maps.newHashMap();

    // olap scan node that contains target of a runtime filter.
    private final Map<RelationId, List<Slot>> targetOnOlapScanNodeMap = Maps.newHashMap();

    // legacy (org.apache.doris.planner) filters produced during translation.
    private final List<org.apache.doris.planner.RuntimeFilter> legacyFilters = Lists.newArrayList();

    // exprId to olap scan node slotRef because the slotRef will be changed when translating.
    private final Map<ExprId, SlotRef> exprIdToOlapScanNodeSlotRef = Maps.newHashMap();

    // index from the hash join that builds a filter to the filters it builds.
    private final Map<PhysicalHashJoin, List<RuntimeFilter>> runtimeFilterOnHashJoinNode = Maps.newHashMap();

    // Alias's child to itself.
    private final Map<Slot, NamedExpression> aliasChildToSelf = Maps.newHashMap();

    // scan node hosting each legacy runtime filter target slot.
    private final Map<Slot, OlapScanNode> scanNodeOfLegacyRuntimeFilterTarget = Maps.newHashMap();

    private final SessionVariable sessionVariable;

    private final FilterSizeLimits limits;

    public RuntimeFilterContext(SessionVariable sessionVariable) {
        this.sessionVariable = sessionVariable;
        this.limits = new FilterSizeLimits(sessionVariable);
    }

    public SessionVariable getSessionVariable() {
        return sessionVariable;
    }

    public FilterSizeLimits getLimits() {
        return limits;
    }

    /**
     * Register runtime filters under the expr id of their target slot.
     *
     * @param id target expr id; every filter's target expr id must equal it
     * @param filters runtime filters targeting {@code id}
     */
    public void setTargetExprIdToFilters(ExprId id, RuntimeFilter... filters) {
        Preconditions.checkArgument(Arrays.stream(filters)
                .allMatch(filter -> filter.getTargetExpr().getExprId() == id));
        this.targetExprIdToFilter.computeIfAbsent(id, k -> Lists.newArrayList())
                .addAll(Arrays.asList(filters));
    }

    /**
     * Returns the filters whose target has the given expr id, or {@code null} if none
     * were registered (callers must handle the null case).
     */
    public List<RuntimeFilter> getFiltersByTargetExprId(ExprId id) {
        return targetExprIdToFilter.get(id);
    }

    public void removeFilters(ExprId id) {
        targetExprIdToFilter.remove(id);
    }

    /** Register target slots on the olap scan node with the given relation id. */
    public void setTargetsOnScanNode(RelationId id, Slot... slots) {
        this.targetOnOlapScanNodeMap.computeIfAbsent(id, k -> Lists.newArrayList())
                .addAll(Arrays.asList(slots));
    }

    // generic helper used by callers that hold one of this context's maps.
    public <K, V> void setKVInNormalMap(@NotNull Map<K, V> map, K key, V value) {
        map.put(key, value);
    }

    public Map<ExprId, SlotRef> getExprIdToOlapScanNodeSlotRef() {
        return exprIdToOlapScanNodeSlotRef;
    }

    public Map<Slot, NamedExpression> getAliasChildToSelf() {
        return aliasChildToSelf;
    }

    public Map<Slot, OlapScanNode> getScanNodeOfLegacyRuntimeFilterTarget() {
        return scanNodeOfLegacyRuntimeFilterTarget;
    }

    /** Returns the filters built by the given hash join (empty list if none). */
    public List<RuntimeFilter> getRuntimeFilterOnHashJoinNode(PhysicalHashJoin join) {
        return runtimeFilterOnHashJoinNode.getOrDefault(join, Collections.emptyList());
    }

    /** Build the (hash join -> runtime filters) index from the registered filters. */
    public void generatePhysicalHashJoinToRuntimeFilter() {
        targetExprIdToFilter.values().forEach(filters -> filters.forEach(filter -> runtimeFilterOnHashJoinNode
                .computeIfAbsent(filter.getBuilderNode(), k -> Lists.newArrayList()).add(filter)));
    }

    public Map<ExprId, List<RuntimeFilter>> getTargetExprIdToFilter() {
        return targetExprIdToFilter;
    }

    public Map<RelationId, List<Slot>> getTargetOnOlapScanNodeMap() {
        return targetOnOlapScanNodeMap;
    }

    public List<org.apache.doris.planner.RuntimeFilter> getLegacyFilters() {
        return legacyFilters;
    }

    // NOTE: despite the "set" name this appends to the legacy filter list.
    public void setLegacyFilter(org.apache.doris.planner.RuntimeFilter filter) {
        this.legacyFilters.add(filter);
    }

    public <K, V> boolean checkExistKey(@NotNull Map<K, V> map, K key) {
        return map.containsKey(key);
    }

    /**
     * get nereids runtime filters, sorted by filter id.
     *
     * <p>Fixed: the previous implementation used {@code Stream.reduce} with a mutable
     * identity ({@code Lists.newArrayList()}) that was mutated by the accumulator, which
     * violates the reduce contract (the identity must not be modified). A plain
     * accumulation loop is both correct and clearer.
     *
     * @return nereids runtime filters
     */
    @VisibleForTesting
    public List<RuntimeFilter> getNereidsRuntimeFilter() {
        List<RuntimeFilter> filters = Lists.newArrayList();
        targetExprIdToFilter.values().forEach(filters::addAll);
        filters.sort((a, b) -> a.getId().compareTo(b.getId()));
        return filters;
    }

    /**
     * get the slot list of the same olap scan node of the input slot, by walking the
     * alias-child chain starting from {@code slot}.
     *
     * <p>NOTE(review): assumes the alias chain is acyclic — a cycle in
     * {@code aliasChildToSelf} would loop forever; confirm the generator never creates one.
     *
     * @param slot slot
     * @return slot list
     */
    public List<NamedExpression> getSlotListOfTheSameSlotAtOlapScanNode(Slot slot) {
        ImmutableList.Builder<NamedExpression> builder = ImmutableList.builder();
        NamedExpression expr = slot;
        do {
            builder.add(expr);
            expr = aliasChildToSelf.get(expr.toSlot());
        } while (expr != null);
        return builder.build();
    }
}

View File

@ -0,0 +1,165 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.processor.post;
import org.apache.doris.common.IdGenerator;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.trees.plans.physical.RuntimeFilter;
import org.apache.doris.planner.RuntimeFilterId;
import org.apache.doris.thrift.TRuntimeFilterType;
import com.google.common.collect.ImmutableSet;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
/**
 * generate runtime filter
 */
public class RuntimeFilterGenerator extends PlanPostProcessor {

    private final IdGenerator<RuntimeFilterId> generator = RuntimeFilterId.createGenerator();

    // join types on which runtime filters must not be generated; for these, any filter
    // already targeting the join's output slots is dropped (see visitPhysicalHashJoin).
    private final ImmutableSet<JoinType> deniedJoinType = ImmutableSet.of(
            JoinType.LEFT_ANTI_JOIN,
            JoinType.RIGHT_ANTI_JOIN,
            JoinType.FULL_OUTER_JOIN,
            JoinType.LEFT_OUTER_JOIN
    );

    /**
     * the runtime filter generator run at the phase of post process and plan translation of nereids planner.
     * post process:
     * first step: if encounter supported join type, generate nereids runtime filter for all the hash conjunctions
     * and make association from exprId of the target slot references to the runtime filter. or delete the runtime
     * filter whose target slot reference is one of the output slot references of the left child of the physical join as
     * the runtime filter.
     * second step: if encounter project, collect the association of its child and it for pushing down through
     * the project node.
     * plan translation:
     * third step: generate nereids runtime filter target at olap scan node fragment.
     * forth step: generate legacy runtime filter target and runtime filter at hash join node fragment.
     */
    // TODO: current support inner join, cross join, right outer join, and will support more join type.
    @Override
    public PhysicalPlan visitPhysicalHashJoin(PhysicalHashJoin<? extends Plan, ? extends Plan> join,
            CascadesContext context) {
        RuntimeFilterContext ctx = context.getRuntimeFilterContext();
        if (deniedJoinType.contains(join.getJoinType())) {
            /* TODO: translate left outer join to inner join if there are inner join ancestors
             * if it has encountered inner join, like
             *                       a=b
             *                      /   \
             *                     /     \
             *                    /       \
             *                   /         \
             *      left join-->a=c         b
             *                 /   \
             *                /     \
             *               /       \
             *              /         \
             *             a           c
             * runtime filter whose src expr is b can take effect on c.
             * but now checking the inner join is unsupported. we may support it at later version.
             */
            // denied join type: drop any filter already targeting this join's output slots.
            join.getOutput().forEach(slot -> ctx.removeFilters(slot.getExprId()));
        } else {
            // runtime filter types enabled in the session (bitmask intersection).
            List<TRuntimeFilterType> legalTypes = Arrays.stream(TRuntimeFilterType.values()).filter(type ->
                    (type.getValue() & ctx.getSessionVariable().getRuntimeFilterType()) > 0)
                    .collect(Collectors.toList());
            // running expr-order counter across all conjunctions and filter types of this join.
            AtomicInteger cnt = new AtomicInteger();
            join.getHashJoinConjuncts().stream()
                    .map(EqualTo.class::cast)
                    // TODO: we will support it in later version.
                    /*.peek(expr -> {
                        // target is always the expr at the two side of equal of hash conjunctions.
                        // TODO: some complex situation cannot be handled now, see testPushDownThroughJoin.
                        List<SlotReference> slots = expr.children().stream().filter(SlotReference.class::isInstance)
                                .map(SlotReference.class::cast).collect(Collectors.toList());
                        if (slots.size() != 2
                                || !(ctx.checkExistKey(ctx.getTargetExprIdToFilter(), slots.get(0).getExprId())
                                || ctx.checkExistKey(ctx.getTargetExprIdToFilter(), slots.get(1).getExprId()))) {
                            return;
                        }
                        int tag = ctx.checkExistKey(ctx.getTargetExprIdToFilter(), slots.get(0).getExprId()) ? 0 : 1;
                        // generate runtime filter to associated expr. for example, a = b and a = c, RF b -> a can
                        // generate RF b -> c
                        List<RuntimeFilter> copiedRuntimeFilter = ctx.getFiltersByTargetExprId(slots.get(tag)
                                .getExprId()).stream()
                                .map(filter -> new RuntimeFilter(generator.getNextId(), filter.getSrcExpr(),
                                        slots.get(tag ^ 1), filter.getType(), filter.getExprOrder(), join))
                                .collect(Collectors.toList());
                        ctx.setTargetExprIdToFilters(slots.get(tag ^ 1).getExprId(),
                                copiedRuntimeFilter.toArray(new RuntimeFilter[0]));
                    })*/
                    // one filter per (conjunction, enabled type); createRuntimeFilter returns
                    // null for unusable conjunctions, filtered out below.
                    .forEach(expr -> legalTypes.stream()
                            .map(type -> RuntimeFilter.createRuntimeFilter(generator.getNextId(), expr,
                                    type, cnt.getAndIncrement(), join))
                            .filter(Objects::nonNull)
                            .forEach(filter ->
                                    ctx.setTargetExprIdToFilters(filter.getTargetExpr().getExprId(), filter)));
        }
        // continue the top-down traversal into both children.
        join.left().accept(this, context);
        join.right().accept(this, context);
        return join;
    }

    // TODO: support src key is agg slot.
    @Override
    public PhysicalPlan visitPhysicalProject(PhysicalProject<? extends Plan> project, CascadesContext context) {
        RuntimeFilterContext ctx = context.getRuntimeFilterContext();
        // record alias-child -> alias so filters can be pushed down through this project.
        project.getProjects().stream().filter(Alias.class::isInstance)
                .map(Alias.class::cast)
                .filter(expr -> expr.child() instanceof SlotReference)
                .forEach(expr -> ctx.setKVInNormalMap(ctx.getAliasChildToSelf(), ((SlotReference) expr.child()), expr));
        project.child().accept(this, context);
        return project;
    }

    @Override
    public PhysicalOlapScan visitPhysicalOlapScan(PhysicalOlapScan scan, CascadesContext context) {
        RuntimeFilterContext ctx = context.getRuntimeFilterContext();
        // For each output slot of the scan: if any expr in its alias chain is a filter target,
        // re-key those filters onto the scan slot itself and register the slot on this scan.
        // NOTE(review): this relies on peek() side effects followed by count(); the preceding
        // filter() keeps the stream size unknown so peek is not elided here, but this pattern
        // is fragile — a plain loop would be safer. Confirm before refactoring.
        scan.getOutput().stream()
                .filter(slot -> ctx.getSlotListOfTheSameSlotAtOlapScanNode(slot).stream()
                        .filter(expr -> ctx.checkExistKey(ctx.getTargetExprIdToFilter(), expr.getExprId()))
                        .peek(expr -> {
                            // the slot itself is already keyed correctly; nothing to move.
                            if (expr.getExprId() == slot.getExprId()) {
                                return;
                            }
                            // move the filters from the alias expr's id onto the scan slot's id,
                            // retargeting each filter at the scan slot.
                            List<RuntimeFilter> filters = ctx.getFiltersByTargetExprId(expr.getExprId());
                            ctx.removeFilters(expr.getExprId());
                            filters.forEach(filter -> filter.setTargetSlot(slot));
                            ctx.setKVInNormalMap(ctx.getTargetExprIdToFilter(), slot.getExprId(), filters);
                        })
                        .count() > 0)
                .forEach(slot -> ctx.setTargetsOnScanNode(scan.getId(), slot));
        return scan;
    }
}

View File

@ -0,0 +1,127 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.plans.physical;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.planner.RuntimeFilterId;
import org.apache.doris.thrift.TRuntimeFilterType;
/**
 * Nereids-side runtime filter: describes a filter whose values come from the build side
 * ({@code srcSlot}) of a hash join and are applied to a probe-side {@code targetSlot},
 * before being translated into the legacy planner's runtime filter form.
 */
public class RuntimeFilter {

    private final RuntimeFilterId id;
    private final TRuntimeFilterType type;
    private final Slot srcSlot;
    // mutable: the target is rewritten when the filter is pushed down through
    // projects onto the underlying olap scan slot (see RuntimeFilterGenerator).
    private Slot targetSlot;
    private final int exprOrder;
    // the hash join that builds this filter; never reassigned, hence final.
    private final PhysicalHashJoin builderNode;
    // set once the legacy filter has been generated from this one.
    private boolean finalized = false;

    /**
     * constructor
     *
     * @param id unique runtime filter id
     * @param src build-side source slot
     * @param target probe-side target slot
     * @param type runtime filter type (e.g. IN / bloom)
     * @param exprOrder position of this filter among the join's generated filters
     * @param builderNode hash join node that builds the filter
     */
    public RuntimeFilter(RuntimeFilterId id, Slot src, Slot target, TRuntimeFilterType type,
            int exprOrder, PhysicalHashJoin builderNode) {
        this.id = id;
        this.srcSlot = src;
        this.targetSlot = target;
        this.type = type;
        this.exprOrder = exprOrder;
        this.builderNode = builderNode;
    }

    /**
     * create RF from one hash-join equal conjunction.
     *
     * @return the new filter, or {@code null} when the conjunction cannot produce one
     *         (literal operand, identical children, or non-slot operands)
     */
    public static RuntimeFilter createRuntimeFilter(RuntimeFilterId id, EqualTo conjunction,
            TRuntimeFilterType type, int exprOrder, PhysicalHashJoin node) {
        Pair<Expression, Expression> srcs = checkAndMaybeSwapChild(conjunction, node);
        if (srcs == null) {
            return null;
        }
        // pair is (probe/target side, build/src side) — see checkAndMaybeSwapChild.
        return new RuntimeFilter(id, ((SlotReference) srcs.second), ((SlotReference) srcs.first), type, exprOrder,
                node);
    }

    /**
     * Validates the conjunction and orders its children as (left/probe side, right/build side).
     * Returns null for conjunctions that cannot back a runtime filter.
     */
    private static Pair<Expression, Expression> checkAndMaybeSwapChild(EqualTo expr,
            PhysicalHashJoin join) {
        // a literal operand means no build-side slot to collect values from.
        if (expr.children().stream().anyMatch(Literal.class::isInstance)) {
            return null;
        }
        // trivially-true conjunction: filtering on it is pointless.
        if (expr.child(0).equals(expr.child(1))) {
            return null;
        }
        // both operands must be plain slot references.
        if (!expr.children().stream().allMatch(SlotReference.class::isInstance)) {
            return null;
        }
        //current we assume that there are certainly different slot reference in equal to.
        //they are not from the same relation.
        // if child(1) belongs to the join's left (probe) output, swap so the probe-side
        // slot comes first.
        int exchangeTag = join.child(0).getOutput().stream().anyMatch(slot -> slot.getExprId().equals(
                ((SlotReference) expr.child(1)).getExprId())) ? 1 : 0;
        return Pair.of(expr.child(exchangeTag), expr.child(1 ^ exchangeTag));
    }

    public Slot getSrcExpr() {
        return srcSlot;
    }

    public Slot getTargetExpr() {
        return targetSlot;
    }

    public RuntimeFilterId getId() {
        return id;
    }

    public TRuntimeFilterType getType() {
        return type;
    }

    public int getExprOrder() {
        return exprOrder;
    }

    public PhysicalHashJoin getBuilderNode() {
        return builderNode;
    }

    /** Retarget the filter, used when pushing it down through projects onto a scan slot. */
    public void setTargetSlot(Slot targetSlot) {
        this.targetSlot = targetSlot;
    }

    public boolean isUninitialized() {
        return !finalized;
    }

    public void setFinalized() {
        this.finalized = true;
    }
}

View File

@ -85,6 +85,11 @@ public class OriginalPlanner extends Planner {
createPlanFragments(queryStmt, analyzer, queryOptions);
}
@Override
public List<RuntimeFilter> getRuntimeFilters() {
return analyzer.getAssignedRuntimeFilter();
}
/**
*/
private void setResultExprScale(Analyzer analyzer, ArrayList<Expr> outputExprs) {

View File

@ -92,4 +92,6 @@ public abstract class Planner {
public abstract DescriptorTable getDescTable();
public abstract List<RuntimeFilter> getRuntimeFilters();
}

View File

@ -143,6 +143,13 @@ public final class RuntimeFilter {
calculateFilterSize(filterSizeLimits);
}
// only for nereids planner
public static RuntimeFilter fromNereidsRuntimeFilter(RuntimeFilterId id, HashJoinNode node, Expr srcExpr,
int exprOrder, Expr origTargetExpr, Map<TupleId, List<SlotId>> targetSlots,
TRuntimeFilterType type, RuntimeFilterGenerator.FilterSizeLimits filterSizeLimits) {
return new RuntimeFilter(id, node, srcExpr, exprOrder, origTargetExpr, targetSlots, type, filterSizeLimits);
}
@Override
public boolean equals(Object obj) {
if (!(obj instanceof RuntimeFilter)) {

View File

@ -263,7 +263,7 @@ public class Coordinator {
this.nextInstanceId = new TUniqueId();
nextInstanceId.setHi(queryId.hi);
nextInstanceId.setLo(queryId.lo + 1);
this.assignedRuntimeFilters = analyzer.getAssignedRuntimeFilter();
this.assignedRuntimeFilters = planner.getRuntimeFilters();
}
// Used for broker load task/export task/update coordinator

View File

@ -197,6 +197,8 @@ public class SessionVariable implements Serializable, Writable {
public static final String ENABLE_FALLBACK_TO_ORIGINAL_PLANNER = "enable_fallback_to_original_planner";
public static final String ENABLE_NEREIDS_RUNTIME_FILTER = "enable_nereids_runtime_filter";
public static final String ENABLE_NEREIDS_REORDER_TO_ELIMINATE_CROSS_JOIN =
"enable_nereids_reorder_to_eliminate_cross_join";
@ -515,6 +517,9 @@ public class SessionVariable implements Serializable, Writable {
@VariableMgr.VarAttr(name = ENABLE_NEREIDS_PLANNER)
private boolean enableNereidsPlanner = false;
@VariableMgr.VarAttr(name = ENABLE_NEREIDS_RUNTIME_FILTER)
private boolean enableNereidsRuntimeFilter = true;
@VariableMgr.VarAttr(name = ENABLE_NEREIDS_REORDER_TO_ELIMINATE_CROSS_JOIN)
private boolean enableNereidsReorderToEliminateCrossJoin = true;
@ -1080,6 +1085,14 @@ public class SessionVariable implements Serializable, Writable {
this.enableNereidsPlanner = enableNereidsPlanner;
}
public boolean isEnableNereidsRuntimeFilter() {
return enableNereidsRuntimeFilter;
}
public void setEnableNereidsRuntimeFilter(boolean enableNereidsRuntimeFilter) {
this.enableNereidsRuntimeFilter = enableNereidsRuntimeFilter;
}
public boolean isEnableNereidsReorderToEliminateCrossJoin() {
return enableNereidsReorderToEliminateCrossJoin;
}

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.datasets.tpch;
import org.apache.doris.nereids.analyzer.Unbound;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpressionUtil;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.utframe.TestWithFeService;
@ -28,6 +29,12 @@ import org.junit.jupiter.api.Assertions;
import java.util.List;
public abstract class AnalyzeCheckTestBase extends TestWithFeService {
@Override
public void runBeforeEach() throws Exception {
// Reset the global expression-id generator so each test case starts with
// deterministic NamedExpression ids.
NamedExpressionUtil.clear();
}
protected void checkAnalyze(String sql) {
LogicalPlan analyzed = analyze(sql);
Assertions.assertTrue(checkBound(analyzed));

View File

@ -48,7 +48,7 @@ import java.util.stream.Collectors;
public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMatchSupported {
@Override
protected void runBeforeAll() throws Exception {
public void runBeforeAll() throws Exception {
createDatabase("test_having");
connectContext.setDatabase("default_cluster:test_having");
createTables(
@ -76,7 +76,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
}
@Override
protected void runBeforeEach() throws Exception {
public void runBeforeEach() throws Exception {
NamedExpressionUtil.clear();
}

View File

@ -0,0 +1,204 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.postprocess;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.datasets.ssb.SSBTestBase;
import org.apache.doris.nereids.datasets.ssb.SSBUtils;
import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator;
import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.processor.post.RuntimeFilterContext;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.trees.plans.physical.RuntimeFilter;
import org.apache.doris.planner.PlanFragment;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Optional;
/**
 * Tests runtime-filter generation and push-down in the Nereids planner, on the SSB schema.
 *
 * <p>Each case plans a SQL statement, translates it, and asserts how many runtime filters
 * the post-processor produced. Uses {@code Assertions.assertEquals(expected, actual)}
 * throughout so a failure reports both the expected and the actual filter count.
 */
public class RuntimeFilterTest extends SSBTestBase {

    @Override
    public void runBeforeAll() throws Exception {
        super.runBeforeAll();
        // Enable the feature under test; runtime filter type 8 selects the kind
        // exercised by these cases.
        connectContext.getSessionVariable().setEnableNereidsRuntimeFilter(true);
        connectContext.getSessionVariable().setRuntimeFilterType(8);
    }

    @Test
    public void testGenerateRuntimeFilter() throws AnalysisException {
        String sql = "SELECT * FROM lineorder JOIN customer on c_custkey = lo_custkey";
        List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
        Assertions.assertEquals(1, filters.size());
    }

    @Test
    public void testGenerateRuntimeFilterByIllegalSrcExpr() throws AnalysisException {
        // Both sides of the equality reference the same column, so no filter is legal.
        String sql = "SELECT * FROM lineorder JOIN customer on c_custkey = c_custkey";
        List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
        Assertions.assertEquals(0, filters.size());
    }

    @Test
    public void testComplexExpressionToRuntimeFilter() throws AnalysisException {
        // Three equi-conjuncts -> one runtime filter per conjunct.
        String sql
                = "SELECT * FROM supplier JOIN customer on c_name = s_name and s_city = c_city and s_nation = c_nation";
        List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
        Assertions.assertEquals(3, filters.size());
    }

    @Test
    public void testNestedJoinGenerateRuntimeFilter() throws AnalysisException {
        String sql = SSBUtils.Q4_1;
        List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
        Assertions.assertEquals(4, filters.size());
    }

    @Test
    public void testSubTreeInUnsupportedJoinType() throws AnalysisException {
        // The left-outer join itself produces no filter, but the inner joins
        // inside each subquery still do.
        String sql = "select c_custkey"
                + " from (select lo_custkey from lineorder inner join dates on lo_orderdate = d_datekey) a"
                + " left outer join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b"
                + " on b.c_custkey = a.lo_custkey";
        List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
        Assertions.assertEquals(3, filters.size());
    }

    @Test
    public void testPushDownEncounterUnsupportedJoinType() throws AnalysisException {
        String sql = "select c_custkey"
                + " from (select lo_custkey from lineorder left outer join dates on lo_orderdate = d_datekey) a"
                + " inner join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b"
                + " on b.c_custkey = a.lo_custkey";
        List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
        Assertions.assertEquals(3, filters.size());
    }

    @Test
    public void testPushDownThroughAggNode() throws AnalysisException {
        // Filter target can pass through an aggregate whose group-by key is the join key.
        String sql = "select profit"
                + " from (select lo_custkey, sum(lo_revenue - lo_supplycost) as profit from lineorder inner join dates"
                + " on lo_orderdate = d_datekey group by lo_custkey) a"
                + " inner join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b"
                + " on b.c_custkey = a.lo_custkey";
        List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
        Assertions.assertEquals(3, filters.size());
    }

    @Test
    public void testDoNotPushDownThroughAggFunction() throws AnalysisException {
        // The join key on one side is an aggregate function result, so that
        // filter is not generated: one fewer than testPushDownThroughAggNode.
        String sql = "select profit"
                + " from (select lo_custkey, sum(lo_revenue - lo_supplycost) as profit from lineorder inner join dates"
                + " on lo_orderdate = d_datekey group by lo_custkey) a"
                + " inner join (select sum(c_custkey) c_custkey from customer inner join supplier on c_custkey = s_suppkey group by s_suppkey) b"
                + " on b.c_custkey = a.lo_custkey";
        List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
        Assertions.assertEquals(2, filters.size());
    }

    @Test
    public void testCrossJoin() throws AnalysisException {
        String sql = "select c_custkey, lo_custkey from lineorder, customer where lo_custkey = c_custkey";
        List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
        Assertions.assertEquals(1, filters.size());
    }

    @Test
    public void testSubQueryAlias() throws AnalysisException {
        String sql = "select c_custkey, lo_custkey from lineorder l, customer c where c.c_custkey = l.lo_custkey";
        List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
        Assertions.assertEquals(1, filters.size());
    }

    @Test
    public void testView() throws Exception {
        createView("create view if not exists v1 as \n"
                + " select * \n"
                + " from customer");
        createView("create view if not exists v2 as\n"
                + " select *\n"
                + " from lineorder");
        createView("create view if not exists v3 as \n"
                + " select *\n"
                + " from v1 join (\n"
                + " select *\n"
                + " from v2\n"
                + " ) t \n"
                + " on v1.c_custkey = t.lo_custkey");
        String sql = "select * from (\n"
                + " select * \n"
                + " from part p \n"
                + " join v2 on p.p_partkey = v2.lo_partkey) t1 \n"
                + " join (\n"
                + " select * \n"
                + " from supplier s \n"
                + " join v3 on s.s_region = v3.c_region) t2 \n"
                + " on t1.p_partkey = t2.lo_partkey\n"
                + " order by t1.lo_custkey, t1.p_partkey, t2.s_suppkey, t2.c_custkey, t2.lo_orderkey";
        List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
        Assertions.assertEquals(4, filters.size());
    }

    @Test
    public void testPushDownThroughJoin() throws AnalysisException {
        String sql = "select c_custkey from (select c_custkey from (select lo_custkey from lineorder inner join dates"
                + " on lo_orderdate = d_datekey) a"
                + " inner join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b"
                + " on b.c_custkey = a.lo_custkey) c inner join (select lo_custkey from customer inner join lineorder"
                + " on c_custkey = lo_custkey) d on c.c_custkey = d.lo_custkey";
        List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
        Assertions.assertEquals(5, filters.size());
    }

    @Test
    public void testPushDownThroughUnsupportedJoinType() throws AnalysisException {
        String sql = "select c_custkey from (select c_custkey from (select lo_custkey from lineorder inner join dates"
                + " on lo_orderdate = d_datekey) a"
                + " inner join (select c_custkey from customer left outer join supplier on c_custkey = s_suppkey) b"
                + " on b.c_custkey = a.lo_custkey) c inner join (select lo_custkey from customer inner join lineorder"
                + " on c_custkey = lo_custkey) d on c.c_custkey = d.lo_custkey";
        List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
        Assertions.assertEquals(5, filters.size());
    }

    /**
     * Plans and translates {@code sql}, then returns the runtime filters recorded in the
     * {@link RuntimeFilterContext}, or {@link Optional#empty()} when translation produced
     * no runtime-filter translator.
     *
     * <p>Also asserts that every Nereids runtime filter was translated to a legacy filter.
     */
    private Optional<List<RuntimeFilter>> getRuntimeFilters(String sql) throws AnalysisException {
        NereidsPlanner planner = new NereidsPlanner(createStatementCtx(sql));
        PhysicalPlan plan = planner.plan(new NereidsParser().parseSingle(sql), PhysicalProperties.ANY);
        System.out.println(plan.treeString());
        PlanTranslatorContext context = new PlanTranslatorContext(planner.getCascadesContext());
        PlanFragment root = new PhysicalPlanTranslator().translatePlan(plan, context);
        System.out.println(root.getFragmentId());
        if (context.getRuntimeTranslator().isPresent()) {
            RuntimeFilterContext ctx = planner.getCascadesContext().getRuntimeFilterContext();
            Assertions.assertEquals(ctx.getNereidsRuntimeFilter().size(), ctx.getLegacyFilters().size());
            return Optional.of(ctx.getNereidsRuntimeFilter());
        }
        return Optional.empty();
    }

    // Helper for asserting a filter's src/target expressions by their SQL form.
    // Currently unused by the cases above; kept for follow-up expression-level checks.
    private boolean checkRuntimeFilterExpr(RuntimeFilter filter, String srcColName, String targetColName) {
        return filter.getSrcExpr().toSql().equals(srcColName)
                && filter.getTargetExpr().toSql().equals(targetColName);
    }
}

View File

@ -99,12 +99,13 @@ public class ViewTest extends TestWithFeService implements PatternMatchSupported
NamedExpressionUtil.clear();
System.out.println("\n\n***** " + sql + " *****\n\n");
StatementContext statementContext = MemoTestUtils.createStatementContext(connectContext, sql);
PhysicalPlan plan = new NereidsPlanner(statementContext).plan(
NereidsPlanner planner = new NereidsPlanner(statementContext);
PhysicalPlan plan = planner.plan(
new NereidsParser().parseSingle(sql),
PhysicalProperties.ANY
);
// Just to check whether translate will throw exception
new PhysicalPlanTranslator().translatePlan(plan, new PlanTranslatorContext());
new PhysicalPlanTranslator().translatePlan(plan, new PlanTranslatorContext(planner.getCascadesContext()));
}
}

View File

@ -89,12 +89,13 @@ public class AnalyzeSubQueryTest extends TestWithFeService implements PatternMat
for (String sql : testSql) {
NamedExpressionUtil.clear();
StatementContext statementContext = MemoTestUtils.createStatementContext(connectContext, sql);
PhysicalPlan plan = new NereidsPlanner(statementContext).plan(
NereidsPlanner planner = new NereidsPlanner(statementContext);
PhysicalPlan plan = planner.plan(
parser.parseSingle(sql),
PhysicalProperties.ANY
);
// Just to check whether translate will throw exception
new PhysicalPlanTranslator().translatePlan(plan, new PlanTranslatorContext());
new PhysicalPlanTranslator().translatePlan(plan, new PlanTranslatorContext(planner.getCascadesContext()));
}
}

View File

@ -0,0 +1,36 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.util;
import org.apache.doris.nereids.pattern.Pattern;
import org.apache.doris.nereids.trees.plans.Plan;
import java.util.stream.IntStream;
/**
 * Utility for matching a physical plan tree against a {@link Pattern}, top-down.
 */
public class PhysicalPlanMatchingUtils {

    // Utility class: static methods only, no instances.
    private PhysicalPlanMatchingUtils() {
    }

    /**
     * Returns true when {@code plan} matches {@code pattern} at the root, satisfies all of
     * the pattern's predicates, and every child recursively matches the corresponding
     * child pattern.
     *
     * @param plan    plan node to check
     * @param pattern pattern to match against, including child patterns
     */
    public static <TYPE extends Plan> boolean topDownFindMatching(Plan plan, Pattern<TYPE> pattern) {
        if (!pattern.matchRoot(plan)) {
            return false;
        }
        // NOTE(review): unchecked cast — assumes matchRoot() succeeding implies plan is a
        // TYPE; confirm against Pattern's contract.
        @SuppressWarnings("unchecked")
        TYPE typedPlan = (TYPE) plan;
        if (!pattern.getPredicates().stream().allMatch(pred -> pred.test(typedPlan))) {
            return false;
        }
        return IntStream.range(0, plan.children().size())
                .allMatch(i -> topDownFindMatching(plan.child(i), pattern.child(i)));
    }
}

View File

@ -43,7 +43,7 @@ suite("view") {
select *
from v2
) t
on v1.c_custkey = t.lo_custkey;
on v1.c_custkey = t.lo_custkey
"""
sql "SET enable_fallback_to_original_planner=false"