[enhancement](Nereids) generate correct distribution spec after project (#13725)
After a project, some Slots may be projected to other Slots, so we need to replace the ExprIds in DistributionSpecHash with the new ones. If the projection is anything other than a plain Alias of a slot, we need to return DistributionSpecAny instead of the child's DistributionSpec.
This commit is contained in:
@ -20,6 +20,10 @@ package org.apache.doris.nereids.properties;
|
||||
import org.apache.doris.nereids.PlanContext;
|
||||
import org.apache.doris.nereids.memo.GroupExpression;
|
||||
import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows;
|
||||
@ -37,9 +41,14 @@ import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
|
||||
import org.apache.doris.nereids.util.JoinUtils;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Maps;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Used for property drive.
|
||||
@ -117,8 +126,37 @@ public class ChildOutputPropertyDeriver extends PlanVisitor<PhysicalProperties,
|
||||
|
||||
@Override
|
||||
public PhysicalProperties visitPhysicalProject(PhysicalProject<? extends Plan> project, PlanContext context) {
|
||||
// TODO: order spec do not process since we do not use it.
|
||||
Preconditions.checkState(childrenOutputProperties.size() == 1);
|
||||
return childrenOutputProperties.get(0);
|
||||
PhysicalProperties childProperties = childrenOutputProperties.get(0);
|
||||
DistributionSpec childDistributionSpec = childProperties.getDistributionSpec();
|
||||
OrderSpec childOrderSpec = childProperties.getOrderSpec();
|
||||
DistributionSpec outputDistributionSpec;
|
||||
if (childDistributionSpec instanceof DistributionSpecHash) {
|
||||
Map<ExprId, ExprId> projections = Maps.newHashMap();
|
||||
Set<ExprId> obstructions = Sets.newHashSet();
|
||||
for (NamedExpression namedExpression : project.getProjects()) {
|
||||
if (namedExpression instanceof Alias) {
|
||||
Alias alias = (Alias) namedExpression;
|
||||
if (alias.child() instanceof SlotReference) {
|
||||
projections.put(((SlotReference) alias.child()).getExprId(), alias.getExprId());
|
||||
} else {
|
||||
obstructions.addAll(
|
||||
alias.child().getInputSlots().stream()
|
||||
.map(NamedExpression::getExprId)
|
||||
.collect(Collectors.toSet()));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (projections.entrySet().stream().allMatch(kv -> kv.getKey().equals(kv.getValue()))) {
|
||||
return childrenOutputProperties.get(0);
|
||||
}
|
||||
outputDistributionSpec = ((DistributionSpecHash) childDistributionSpec).project(projections, obstructions);
|
||||
return new PhysicalProperties(outputDistributionSpec, childOrderSpec);
|
||||
} else {
|
||||
return childrenOutputProperties.get(0);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -208,16 +208,50 @@ public class DistributionSpecHash extends DistributionSpec {
|
||||
equivalenceExprIds, exprIdToEquivalenceSet);
|
||||
}
|
||||
|
||||
/**
|
||||
* generate a new DistributionSpec after projection.
|
||||
*/
|
||||
public DistributionSpec project(Map<ExprId, ExprId> projections, Set<ExprId> obstructions) {
|
||||
List<ExprId> orderedShuffledColumns = Lists.newArrayList();
|
||||
List<Set<ExprId>> equivalenceExprIds = Lists.newArrayList();
|
||||
Map<ExprId, Integer> exprIdToEquivalenceSet = Maps.newHashMap();
|
||||
for (ExprId shuffledColumn : this.orderedShuffledColumns) {
|
||||
if (obstructions.contains(shuffledColumn)) {
|
||||
return DistributionSpecAny.INSTANCE;
|
||||
}
|
||||
orderedShuffledColumns.add(projections.getOrDefault(shuffledColumn, shuffledColumn));
|
||||
}
|
||||
for (Set<ExprId> equivalenceSet : this.equivalenceExprIds) {
|
||||
Set<ExprId> projectionEquivalenceSet = Sets.newHashSet();
|
||||
for (ExprId equivalence : equivalenceSet) {
|
||||
if (obstructions.contains(equivalence)) {
|
||||
return DistributionSpecAny.INSTANCE;
|
||||
}
|
||||
projectionEquivalenceSet.add(projections.getOrDefault(equivalence, equivalence));
|
||||
}
|
||||
equivalenceExprIds.add(projectionEquivalenceSet);
|
||||
}
|
||||
for (Map.Entry<ExprId, Integer> exprIdSetKV : this.exprIdToEquivalenceSet.entrySet()) {
|
||||
if (obstructions.contains(exprIdSetKV.getKey())) {
|
||||
return DistributionSpecAny.INSTANCE;
|
||||
}
|
||||
if (projections.containsKey(exprIdSetKV.getKey())) {
|
||||
exprIdToEquivalenceSet.put(projections.get(exprIdSetKV.getKey()), exprIdSetKV.getValue());
|
||||
} else {
|
||||
exprIdToEquivalenceSet.put(exprIdSetKV.getKey(), exprIdSetKV.getValue());
|
||||
}
|
||||
}
|
||||
return new DistributionSpecHash(orderedShuffledColumns, shuffleType, tableId, partitionIds,
|
||||
equivalenceExprIds, exprIdToEquivalenceSet);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (!super.equals(o)) {
|
||||
return false;
|
||||
}
|
||||
DistributionSpecHash that = (DistributionSpecHash) o;
|
||||
//TODO: that.orderedShuffledColumns may have equivalent slots. This will be done later
|
||||
return shuffleType == that.shuffleType
|
||||
&& orderedShuffledColumns.size() == that.orderedShuffledColumns.size()
|
||||
&& equalsSatisfy(that.orderedShuffledColumns);
|
||||
return shuffleType == that.shuffleType && orderedShuffledColumns.equals(that.orderedShuffledColumns);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -32,6 +32,7 @@ import java.util.Objects;
|
||||
* Spec of sort order.
|
||||
*/
|
||||
public class OrderSpec {
|
||||
// TODO: use a OrderKey with ExprId list to instead of current orderKeys for easy to use.
|
||||
private final List<OrderKey> orderKeys;
|
||||
|
||||
public OrderSpec() {
|
||||
|
||||
@ -27,6 +27,7 @@ import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
public class DistributionSpecHashTest {
|
||||
|
||||
@ -79,6 +80,49 @@ public class DistributionSpecHashTest {
|
||||
Assertions.assertEquals(expected, DistributionSpecHash.merge(natural, join));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testProject() {
|
||||
Map<ExprId, Integer> naturalMap = Maps.newHashMap();
|
||||
naturalMap.put(new ExprId(0), 0);
|
||||
naturalMap.put(new ExprId(1), 0);
|
||||
naturalMap.put(new ExprId(2), 1);
|
||||
naturalMap.put(new ExprId(3), 1);
|
||||
|
||||
DistributionSpecHash original = new DistributionSpecHash(
|
||||
Lists.newArrayList(new ExprId(0), new ExprId(2)),
|
||||
ShuffleType.NATURAL,
|
||||
0,
|
||||
Sets.newHashSet(0L),
|
||||
Lists.newArrayList(Sets.newHashSet(new ExprId(0), new ExprId(1)), Sets.newHashSet(new ExprId(2), new ExprId(3))),
|
||||
naturalMap
|
||||
);
|
||||
|
||||
Map<ExprId, ExprId> projects = Maps.newHashMap();
|
||||
projects.put(new ExprId(2), new ExprId(5));
|
||||
Set<ExprId> obstructions = Sets.newHashSet();
|
||||
|
||||
DistributionSpec after = original.project(projects, obstructions);
|
||||
Assertions.assertTrue(after instanceof DistributionSpecHash);
|
||||
DistributionSpecHash afterHash = (DistributionSpecHash) after;
|
||||
Assertions.assertEquals(Lists.newArrayList(new ExprId(0), new ExprId(5)), afterHash.getOrderedShuffledColumns());
|
||||
Assertions.assertEquals(
|
||||
Lists.newArrayList(
|
||||
Sets.newHashSet(new ExprId(0), new ExprId(1)),
|
||||
Sets.newHashSet(new ExprId(5), new ExprId(3))),
|
||||
afterHash.getEquivalenceExprIds());
|
||||
Map<ExprId, Integer> actualMap = Maps.newHashMap();
|
||||
actualMap.put(new ExprId(0), 0);
|
||||
actualMap.put(new ExprId(1), 0);
|
||||
actualMap.put(new ExprId(5), 1);
|
||||
actualMap.put(new ExprId(3), 1);
|
||||
Assertions.assertEquals(actualMap, afterHash.getExprIdToEquivalenceSet());
|
||||
|
||||
// have obstructions
|
||||
obstructions.add(new ExprId(3));
|
||||
after = original.project(projects, obstructions);
|
||||
Assertions.assertTrue(after instanceof DistributionSpecAny);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSatisfyAny() {
|
||||
DistributionSpec required = DistributionSpecAny.INSTANCE;
|
||||
|
||||
Reference in New Issue
Block a user