[fix](Nereids) storage later agg rule process agg children by mistake (#26101)
update Project#findProject agg function's children could be any expression rather than only slot. we use Project#findProject to process them. But this util could only process slot. This PR update this util to let it could process all type expression.
This commit is contained in:
@ -397,7 +397,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
|
||||
if (project != null) {
|
||||
argumentsOfAggregateFunction = Project.findProject(
|
||||
(List<SlotReference>) (List) argumentsOfAggregateFunction, project.getProjects())
|
||||
argumentsOfAggregateFunction, project.getProjects())
|
||||
.stream()
|
||||
.map(p -> p instanceof Alias ? p.child(0) : p)
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
@ -431,8 +431,8 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
Set<SlotReference> aggUsedSlots =
|
||||
ExpressionUtils.collect(argumentsOfAggregateFunction, SlotReference.class::isInstance);
|
||||
|
||||
List<SlotReference> usedSlotInTable = (List<SlotReference>) (List) Project.findProject(aggUsedSlots,
|
||||
(List<NamedExpression>) (List) logicalScan.getOutput());
|
||||
List<SlotReference> usedSlotInTable = (List<SlotReference>) Project.findProject(aggUsedSlots,
|
||||
logicalScan.getOutput());
|
||||
|
||||
for (SlotReference slot : usedSlotInTable) {
|
||||
Column column = slot.getColumn().get();
|
||||
|
||||
@ -41,8 +41,8 @@ public class EliminateAggregate extends OneRewriteRuleFactory {
|
||||
if (!onlyHasSlots(outerAgg.getOutputExpressions())) {
|
||||
return outerAgg;
|
||||
}
|
||||
List<NamedExpression> prunedInnerAggOutput = Project.findProject(outerAgg.getOutputSet(),
|
||||
innerAgg.getOutputExpressions());
|
||||
List<NamedExpression> prunedInnerAggOutput = (List<NamedExpression>) Project.findProject(
|
||||
outerAgg.getOutputSet(), innerAgg.getOutputExpressions());
|
||||
return innerAgg.withAggOutput(prunedInnerAggOutput);
|
||||
}).toRule(RuleType.ELIMINATE_AGGREGATE);
|
||||
}
|
||||
|
||||
@ -23,9 +23,9 @@ import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.PlanUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
|
||||
import java.util.Collection;
|
||||
@ -78,22 +78,25 @@ public interface Project {
|
||||
/**
|
||||
* find projects, if not found the slot, then throw AnalysisException
|
||||
*/
|
||||
static List<NamedExpression> findProject(
|
||||
Collection<? extends Slot> slotReferences,
|
||||
List<NamedExpression> projects) throws AnalysisException {
|
||||
static List<? extends Expression> findProject(
|
||||
Collection<? extends Expression> expressions,
|
||||
List<? extends NamedExpression> projects) throws AnalysisException {
|
||||
Map<ExprId, NamedExpression> exprIdToProject = projects.stream()
|
||||
.collect(ImmutableMap.toImmutableMap(p -> p.getExprId(), p -> p));
|
||||
.collect(ImmutableMap.toImmutableMap(NamedExpression::getExprId, p -> p));
|
||||
|
||||
return slotReferences.stream()
|
||||
.map(slot -> {
|
||||
ExprId exprId = slot.getExprId();
|
||||
NamedExpression project = exprIdToProject.get(exprId);
|
||||
if (project == null) {
|
||||
throw new AnalysisException("ExprId " + slot.getExprId() + " no exists in " + projects);
|
||||
return ExpressionUtils.rewriteDownShortCircuit(expressions,
|
||||
expr -> {
|
||||
if (expr instanceof Slot) {
|
||||
Slot slot = (Slot) expr;
|
||||
ExprId exprId = slot.getExprId();
|
||||
NamedExpression project = exprIdToProject.get(exprId);
|
||||
if (project == null) {
|
||||
throw new AnalysisException("ExprId " + slot.getExprId() + " no exists in " + projects);
|
||||
}
|
||||
return project;
|
||||
}
|
||||
return project;
|
||||
})
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
return expr;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -282,7 +282,7 @@ public class ExpressionUtils {
|
||||
}
|
||||
|
||||
public static <E extends Expression> List<E> rewriteDownShortCircuit(
|
||||
List<E> exprs, Function<Expression, Expression> rewriteFunction) {
|
||||
Collection<E> exprs, Function<Expression, Expression> rewriteFunction) {
|
||||
return exprs.stream()
|
||||
.map(expr -> (E) expr.rewriteDownShortCircuit(rewriteFunction))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
|
||||
Reference in New Issue
Block a user