[fix](nereids) fix bugs in nereids window function (#17284)
fix two problems: 1. push agg-fun in windowExpression down to AggregateNode for example, sql: select sum(sum(a)) over (order by b) Plan: windowExpression( sum(y) over (order by b)) +--- Agg(sum(a) as y, b) 2. push other expr to upper proj for example, sql: select sum(a+1) over () Plan: windowExpression(sum(y) over ()) +--- Project(a + 1 as y,...) +--- Agg(a,...)
This commit is contained in:
@ -107,9 +107,10 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
|
||||
Expr staleExpr = expression.accept(INSTANCE, context);
|
||||
try {
|
||||
staleExpr.finalizeForNereids();
|
||||
} catch (org.apache.doris.common.AnalysisException e) {
|
||||
} catch (Exception e) {
|
||||
throw new AnalysisException(
|
||||
"Translate Nereids expression to stale expression failed. " + e.getMessage(), e);
|
||||
"Translate Nereids expression `" + expression.toSql()
|
||||
+ "` to stale expression failed. " + e.getMessage(), e);
|
||||
}
|
||||
return staleExpr;
|
||||
}
|
||||
|
||||
@ -64,10 +64,10 @@ public class CheckAfterRewrite extends OneAnalysisRuleFactory {
|
||||
.collect(Collectors.toSet());
|
||||
notFromChildren = removeValidSlotsNotFromChildren(notFromChildren, childrenOutput);
|
||||
if (!notFromChildren.isEmpty()) {
|
||||
throw new AnalysisException(String.format("Input slot(s) not in child's output: %s",
|
||||
throw new AnalysisException(String.format("Input slot(s) not in child's output: %s in plan: %s",
|
||||
StringUtils.join(notFromChildren.stream()
|
||||
.map(ExpressionTrait::toSql)
|
||||
.collect(Collectors.toSet()), ", ")));
|
||||
.collect(Collectors.toSet()), ", "), plan));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.WindowExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
@ -38,6 +39,8 @@ import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
/**
|
||||
* extract window expressions from LogicalProject.projects and Normalize LogicalWindow
|
||||
@ -94,14 +97,44 @@ public class ExtractAndNormalizeWindowExpression extends OneRewriteRuleFactory i
|
||||
// bottomProjects includes:
|
||||
// 1. expressions from function and WindowSpec's partitionKeys and orderKeys
|
||||
// 2. other slots of outputExpressions
|
||||
/*
|
||||
avg(c) / sum(a+1) over (order by avg(b)) group by a
|
||||
win(x/sum(z) over y)
|
||||
prj(x, y, a+1 as z)
|
||||
agg(avg(c) x, avg(b) y, a)
|
||||
proj(a b c)
|
||||
toBePushDown = {avg(c), a+1, avg(b)}
|
||||
*/
|
||||
return expressions.stream()
|
||||
.flatMap(expression -> {
|
||||
if (expression.anyMatch(WindowExpression.class::isInstance)) {
|
||||
Set<Slot> inputSlots = expression.getInputSlots().stream().collect(Collectors.toSet());
|
||||
Set<WindowExpression> collects = expression.collect(WindowExpression.class::isInstance);
|
||||
return collects.stream().flatMap(windowExpression ->
|
||||
windowExpression.getExpressionsInWindowSpec().stream()
|
||||
// constant arguments may in WindowFunctions(e.g. Lead, Lag), which shouldn't be pushed down
|
||||
.filter(expr -> !expr.isConstant())
|
||||
Set<Slot> windowInputSlots = collects.stream().flatMap(
|
||||
win -> win.getInputSlots().stream()
|
||||
).collect(Collectors.toSet());
|
||||
/*
|
||||
substr(
|
||||
ref_1.cp_type,
|
||||
max(
|
||||
cast(ref_1.`cp_catalog_page_number` as int)) over (...)
|
||||
),
|
||||
1)
|
||||
|
||||
in above case, ref_1.cp_type should be pushed down. ref_1.cp_type is in
|
||||
substr.inputSlots, but not in windowExpression.inputSlots
|
||||
|
||||
inputSlots= {ref_1.cp_type}
|
||||
*/
|
||||
inputSlots.removeAll(windowInputSlots);
|
||||
return Stream.concat(
|
||||
collects.stream().flatMap(windowExpression ->
|
||||
windowExpression.getExpressionsInWindowSpec().stream()
|
||||
// constant arguments may in WindowFunctions(e.g. Lead, Lag)
|
||||
// which shouldn't be pushed down
|
||||
.filter(expr -> !expr.isConstant())
|
||||
),
|
||||
inputSlots.stream()
|
||||
);
|
||||
}
|
||||
return ImmutableList.of(expression).stream();
|
||||
|
||||
@ -34,10 +34,13 @@ import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
/**
|
||||
* normalize aggregate's group keys and AggregateFunction's child to SlotReference
|
||||
@ -57,6 +60,31 @@ import java.util.stream.Collectors;
|
||||
* +-- Aggregate(keys:[k1#1, SR#9], outputs:[k1#1, SR#9, Alias(SUM(v1#3))#10, Alias(SUM(v1#3 + 1))#11])
|
||||
* +-- Project(k1#1, Alias(K2#2 + 1)#9, v1#3)
|
||||
* <p>
|
||||
*
|
||||
* Note: window function will be moved to upper project
|
||||
* all agg functions except the top agg should be pushed to Aggregate node.
|
||||
* example 1:
|
||||
* select min(x), sum(x) over () ...
|
||||
* the 'sum(x)' is top agg of window function, it should be moved to upper project
|
||||
* plan:
|
||||
* project(sum(x) over())
|
||||
* Aggregate(min(x), x)
|
||||
*
|
||||
* example 2:
|
||||
* select min(x), avg(sum(x)) over() ...
|
||||
* the 'sum(x)' should be moved to Aggregate
|
||||
* plan:
|
||||
* project(avg(y) over())
|
||||
* Aggregate(min(x), sum(x) as y)
|
||||
* example 3:
|
||||
* select sum(x+1), x+1, sum(x+1) over() ...
|
||||
* window function should use x instead of x+1
|
||||
* plan:
|
||||
* project(sum(x+1) over())
|
||||
* Agg(sum(y), x)
|
||||
* project(x+1 as y)
|
||||
*
|
||||
*
|
||||
* More example could get from UT {NormalizeAggregateTest}
|
||||
*/
|
||||
public class NormalizeAggregate extends OneRewriteRuleFactory implements NormalizeToSlot {
|
||||
@ -64,8 +92,37 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
|
||||
public Rule build() {
|
||||
return logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> {
|
||||
// push expression to bottom project
|
||||
Set<Alias> existsAliases = ExpressionUtils.collect(
|
||||
Set<Alias> existsAliases = ExpressionUtils.mutableCollect(
|
||||
aggregate.getOutputExpressions(), Alias.class::isInstance);
|
||||
Set<AggregateFunction> aggregateFunctionsInWindow = collectAggregateFunctionsInWindow(
|
||||
aggregate.getOutputExpressions());
|
||||
Set<Expression> existsAggAlias = existsAliases.stream().map(alias -> alias.child())
|
||||
.filter(AggregateFunction.class::isInstance)
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
/*
|
||||
* agg-functions inside window function is regarded as an output of aggregate.
|
||||
* select sum(avg(c)) over ...
|
||||
* is regarded as
|
||||
* select avg(c), sum(avg(c)) over ...
|
||||
*
|
||||
* the plan:
|
||||
* project(sum(y) over)
|
||||
* Aggregate(avg(c) as y)
|
||||
*
|
||||
* after Aggregate, the 'y' is removed by upper project.
|
||||
*
|
||||
* aliasOfAggFunInWindowUsedAsAggOutput = {alias(avg(c))}
|
||||
*/
|
||||
List<Alias> aliasOfAggFunInWindowUsedAsAggOutput = Lists.newArrayList();
|
||||
|
||||
for (AggregateFunction aggFun : aggregateFunctionsInWindow) {
|
||||
if (!existsAggAlias.contains(aggFun)) {
|
||||
Alias alias = new Alias(aggFun, aggFun.toSql());
|
||||
existsAliases.add(alias);
|
||||
aliasOfAggFunInWindowUsedAsAggOutput.add(alias);
|
||||
}
|
||||
}
|
||||
Set<Expression> needToSlots = collectGroupByAndArgumentsOfAggregateFunctions(aggregate);
|
||||
NormalizeToSlotContext groupByAndArgumentToSlotContext =
|
||||
NormalizeToSlotContext.buildContext(existsAliases, needToSlots);
|
||||
@ -80,13 +137,22 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
|
||||
// replace groupBy and arguments of aggregate function to slot, may be this output contains
|
||||
// some expression on the aggregate functions, e.g. `sum(value) + 1`, we should replace
|
||||
// the sum(value) to slot and move the `slot + 1` to the upper project later.
|
||||
List<NamedExpression> normalizeOutputPhase1 = aggregate.getOutputExpressions().stream()
|
||||
.map(expr -> {
|
||||
if (expr.anyMatch(WindowExpression.class::isInstance)) {
|
||||
return expr;
|
||||
}
|
||||
return groupByAndArgumentToSlotContext.normalizeToUseSlotRef(expr);
|
||||
}).collect(Collectors.toList());
|
||||
List<NamedExpression> normalizeOutputPhase1 = Stream.concat(
|
||||
aggregate.getOutputExpressions().stream(),
|
||||
aliasOfAggFunInWindowUsedAsAggOutput.stream())
|
||||
.map(expr -> groupByAndArgumentToSlotContext
|
||||
.normalizeToUseSlotRefUp(expr, WindowExpression.class::isInstance))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
Set<Slot> windowInputSlots = collectWindowInputSlots(aggregate.getOutputExpressions());
|
||||
Set<Expression> itemsInWindow = Sets.newHashSet(windowInputSlots);
|
||||
itemsInWindow.addAll(aggregateFunctionsInWindow);
|
||||
NormalizeToSlotContext windowToSlotContext =
|
||||
NormalizeToSlotContext.buildContext(existsAliases, itemsInWindow);
|
||||
normalizeOutputPhase1 = normalizeOutputPhase1.stream()
|
||||
.map(expr -> windowToSlotContext
|
||||
.normalizeToUseSlotRefDown(expr, WindowExpression.class::isInstance, true))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
Set<AggregateFunction> normalizedAggregateFunctions =
|
||||
collectNonWindowedAggregateFunctions(normalizeOutputPhase1);
|
||||
@ -115,14 +181,10 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
|
||||
LogicalAggregate<Plan> normalizedAggregate = aggregate.withNormalized(
|
||||
(List) normalizedGroupBy, normalizedAggregateOutput, normalizedChild);
|
||||
|
||||
normalizeOutputPhase1.removeAll(aliasOfAggFunInWindowUsedAsAggOutput);
|
||||
// exclude same-name functions in WindowExpression
|
||||
List<NamedExpression> upperProjects = normalizeOutputPhase1.stream()
|
||||
.map(expr -> {
|
||||
if (expr.anyMatch(WindowExpression.class::isInstance)) {
|
||||
return expr;
|
||||
}
|
||||
return aggregateFunctionToSlotContext.normalizeToUseSlotRef(expr);
|
||||
}).collect(Collectors.toList());
|
||||
.map(aggregateFunctionToSlotContext::normalizeToUseSlotRef).collect(Collectors.toList());
|
||||
return new LogicalProject<>(upperProjects, normalizedAggregate);
|
||||
}).toRule(RuleType.NORMALIZE_AGGREGATE);
|
||||
}
|
||||
@ -146,21 +208,61 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
|
||||
}))
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
|
||||
Set<Expression> windowFunctionKeys = collectWindowFunctionKeys(aggregate.getOutputExpressions());
|
||||
|
||||
ImmutableSet<Expression> needPushDown = ImmutableSet.<Expression>builder()
|
||||
// group by should be pushed down, e.g. group by (k + 1),
|
||||
// we should push down the `k + 1` to the bottom plan
|
||||
.addAll(groupingByExpr)
|
||||
// e.g. sum(k + 1), we should push down the `k + 1` to the bottom plan
|
||||
.addAll(argumentsOfAggregateFunction)
|
||||
.addAll(windowFunctionKeys)
|
||||
.build();
|
||||
return needPushDown;
|
||||
}
|
||||
|
||||
private Set<AggregateFunction> collectNonWindowedAggregateFunctions(List<NamedExpression> aggOutput) {
|
||||
List<Expression> expressionsWithoutWindow = aggOutput.stream()
|
||||
.filter(expr -> !expr.anyMatch(WindowExpression.class::isInstance))
|
||||
.collect(Collectors.toList());
|
||||
private Set<Expression> collectWindowFunctionKeys(List<NamedExpression> aggOutput) {
|
||||
Set<Expression> windowInputs = Sets.newHashSet();
|
||||
for (Expression expr : aggOutput) {
|
||||
Set<WindowExpression> windows = expr.collect(WindowExpression.class::isInstance);
|
||||
for (WindowExpression win : windows) {
|
||||
windowInputs.addAll(win.getPartitionKeys().stream().flatMap(pk -> pk.getInputSlots().stream()).collect(
|
||||
Collectors.toList()));
|
||||
windowInputs.addAll(win.getOrderKeys().stream().flatMap(ok -> ok.getInputSlots().stream()).collect(
|
||||
Collectors.toList()));
|
||||
}
|
||||
}
|
||||
return windowInputs;
|
||||
}
|
||||
|
||||
return ExpressionUtils.collect(expressionsWithoutWindow, AggregateFunction.class::isInstance);
|
||||
/**
|
||||
* select sum(c2), avg(min(c2)) over (partition by max(c1) order by count(c1)) from T ...
|
||||
* extract {sum, min, max, count}. avg is not extracted.
|
||||
*/
|
||||
private Set<AggregateFunction> collectNonWindowedAggregateFunctions(List<NamedExpression> aggOutput) {
|
||||
return ExpressionUtils.collect(aggOutput, expr -> {
|
||||
if (expr instanceof AggregateFunction) {
|
||||
return !((AggregateFunction) expr).isWindowFunction();
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
private Set<AggregateFunction> collectAggregateFunctionsInWindow(List<NamedExpression> aggOutput) {
|
||||
|
||||
List<WindowExpression> windows = Lists.newArrayList(
|
||||
ExpressionUtils.collect(aggOutput, WindowExpression.class::isInstance));
|
||||
return ExpressionUtils.collect(windows, expr -> {
|
||||
if (expr instanceof AggregateFunction) {
|
||||
return !((AggregateFunction) expr).isWindowFunction();
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
private Set<Slot> collectWindowInputSlots(List<NamedExpression> aggOutput) {
|
||||
List<WindowExpression> windows = Lists.newArrayList(
|
||||
ExpressionUtils.collect(aggOutput, WindowExpression.class::isInstance));
|
||||
return windows.stream().flatMap(win -> win.getInputSlots().stream()).collect(Collectors.toSet());
|
||||
}
|
||||
}
|
||||
|
||||
@ -31,6 +31,7 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.function.BiFunction;
|
||||
import java.util.function.Predicate;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/** NormalizeToSlot */
|
||||
@ -88,6 +89,24 @@ public interface NormalizeToSlot {
|
||||
})).collect(ImmutableList.toImmutableList());
|
||||
}
|
||||
|
||||
public <E extends Expression> E normalizeToUseSlotRefUp(E expression, Predicate skip) {
|
||||
return (E) expression.rewriteDownShortCircuitUp(child -> {
|
||||
NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child);
|
||||
return normalizeToSlotTriplet == null ? child : normalizeToSlotTriplet.remainExpr;
|
||||
}, skip);
|
||||
}
|
||||
|
||||
/**
|
||||
* rewrite subtrees whose root matches predicate border
|
||||
* when we traverse to the node satisfies border predicate, aboveBorder becomes false
|
||||
*/
|
||||
public <E extends Expression> E normalizeToUseSlotRefDown(E expression, Predicate border, boolean aboveBorder) {
|
||||
return (E) expression.rewriteDownShortCircuitDown(child -> {
|
||||
NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child);
|
||||
return normalizeToSlotTriplet == null ? child : normalizeToSlotTriplet.remainExpr;
|
||||
}, border, aboveBorder);
|
||||
}
|
||||
|
||||
/**
|
||||
* generate bottom projections with groupByExpressions.
|
||||
* eg:
|
||||
|
||||
@ -96,6 +96,64 @@ public interface TreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>> {
|
||||
return currentNode;
|
||||
}
|
||||
|
||||
/**
|
||||
* same as rewriteDownShortCircuit,
|
||||
* except that subtrees, whose root satisfies predicate is satisfied, are not rewritten
|
||||
*/
|
||||
default NODE_TYPE rewriteDownShortCircuitUp(Function<NODE_TYPE, NODE_TYPE> rewriteFunction, Predicate skip) {
|
||||
NODE_TYPE currentNode = rewriteFunction.apply((NODE_TYPE) this);
|
||||
if (skip.test(currentNode)) {
|
||||
return currentNode;
|
||||
}
|
||||
if (currentNode == this) {
|
||||
Builder<NODE_TYPE> newChildren = ImmutableList.builderWithExpectedSize(arity());
|
||||
boolean changed = false;
|
||||
for (NODE_TYPE child : children()) {
|
||||
NODE_TYPE newChild = child.rewriteDownShortCircuitUp(rewriteFunction, skip);
|
||||
if (child != newChild) {
|
||||
changed = true;
|
||||
}
|
||||
newChildren.add(newChild);
|
||||
}
|
||||
|
||||
if (changed) {
|
||||
currentNode = currentNode.withChildren(newChildren.build());
|
||||
}
|
||||
}
|
||||
return currentNode;
|
||||
}
|
||||
|
||||
/**
|
||||
* similar to rewriteDownShortCircuit, except that only subtrees, whose root satisfies
|
||||
* border predicate are rewritten.
|
||||
*/
|
||||
default NODE_TYPE rewriteDownShortCircuitDown(Function<NODE_TYPE, NODE_TYPE> rewriteFunction,
|
||||
Predicate border, boolean aboveBorder) {
|
||||
NODE_TYPE currentNode = (NODE_TYPE) this;
|
||||
if (border.test(this)) {
|
||||
aboveBorder = false;
|
||||
}
|
||||
if (!aboveBorder) {
|
||||
currentNode = rewriteFunction.apply((NODE_TYPE) this);
|
||||
}
|
||||
if (currentNode == this) {
|
||||
Builder<NODE_TYPE> newChildren = ImmutableList.builderWithExpectedSize(arity());
|
||||
boolean changed = false;
|
||||
for (NODE_TYPE child : children()) {
|
||||
NODE_TYPE newChild = child.rewriteDownShortCircuitDown(rewriteFunction, border, aboveBorder);
|
||||
if (child != newChild) {
|
||||
changed = true;
|
||||
}
|
||||
newChildren.add(newChild);
|
||||
}
|
||||
|
||||
if (changed) {
|
||||
currentNode = currentNode.withChildren(newChildren.build());
|
||||
}
|
||||
}
|
||||
return currentNode;
|
||||
}
|
||||
|
||||
/**
|
||||
* bottom-up rewrite.
|
||||
* @param rewriteFunction rewrite function.
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.trees.expressions;
|
||||
|
||||
import org.apache.doris.nereids.exceptions.UnboundException;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
|
||||
@ -53,6 +54,9 @@ public class WindowExpression extends Expression {
|
||||
.addAll(orderKeys)
|
||||
.build().toArray(new Expression[0]));
|
||||
this.function = function;
|
||||
if (function instanceof AggregateFunction) {
|
||||
((AggregateFunction) function).setWindowFunction(true);
|
||||
}
|
||||
this.partitionKeys = ImmutableList.copyOf(partitionKeys);
|
||||
this.orderKeys = ImmutableList.copyOf(orderKeys);
|
||||
this.windowFrame = Optional.empty();
|
||||
@ -68,6 +72,9 @@ public class WindowExpression extends Expression {
|
||||
.add(windowFrame)
|
||||
.build().toArray(new Expression[0]));
|
||||
this.function = function;
|
||||
if (function instanceof AggregateFunction) {
|
||||
((AggregateFunction) function).setWindowFunction(true);
|
||||
}
|
||||
this.partitionKeys = ImmutableList.copyOf(partitionKeys);
|
||||
this.orderKeys = ImmutableList.copyOf(orderKeys);
|
||||
this.windowFrame = Optional.of(Objects.requireNonNull(windowFrame));
|
||||
|
||||
@ -38,6 +38,7 @@ import java.util.stream.Collectors;
|
||||
public abstract class AggregateFunction extends BoundFunction implements ExpectsInputTypes {
|
||||
|
||||
protected final boolean distinct;
|
||||
protected boolean isWindowFunction = false;
|
||||
|
||||
public AggregateFunction(String name, Expression... arguments) {
|
||||
this(name, false, arguments);
|
||||
@ -77,6 +78,14 @@ public abstract class AggregateFunction extends BoundFunction implements Expects
|
||||
return distinct;
|
||||
}
|
||||
|
||||
public boolean isWindowFunction() {
|
||||
return isWindowFunction;
|
||||
}
|
||||
|
||||
public void setWindowFunction(boolean windowFunction) {
|
||||
isWindowFunction = windowFunction;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
@ -86,7 +95,8 @@ public abstract class AggregateFunction extends BoundFunction implements Expects
|
||||
return false;
|
||||
}
|
||||
AggregateFunction that = (AggregateFunction) o;
|
||||
return Objects.equals(distinct, that.distinct)
|
||||
return isWindowFunction == that.isWindowFunction
|
||||
&& Objects.equals(distinct, that.distinct)
|
||||
&& Objects.equals(getName(), that.getName())
|
||||
&& Objects.equals(children, that.children);
|
||||
}
|
||||
@ -123,4 +133,5 @@ public abstract class AggregateFunction extends BoundFunction implements Expects
|
||||
.collect(Collectors.joining(", "));
|
||||
return getName() + "(" + (distinct ? "DISTINCT " : "") + args + ")";
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -411,6 +411,13 @@ public class ExpressionUtils {
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
}
|
||||
|
||||
public static <E> Set<E> mutableCollect(List<? extends Expression> expressions,
|
||||
Predicate<TreeNode<Expression>> predicate) {
|
||||
return expressions.stream()
|
||||
.flatMap(expr -> expr.<Set<E>>collect(predicate).stream())
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
public static <E> List<E> collectAll(List<? extends Expression> expressions,
|
||||
Predicate<TreeNode<Expression>> predicate) {
|
||||
return expressions.stream()
|
||||
|
||||
@ -106,7 +106,11 @@ public class AnalyticEvalNode extends PlanNode {
|
||||
AnalyticWindow analyticWindow, TupleDescriptor intermediateTupleDesc,
|
||||
TupleDescriptor outputTupleDesc, Expr partitionByEq, Expr orderByEq,
|
||||
TupleDescriptor bufferedTupleDesc) {
|
||||
super(id, input.getTupleIds(), "ANALYTIC", StatisticalType.ANALYTIC_EVAL_NODE);
|
||||
super(id,
|
||||
(input.getOutputTupleDesc() != null
|
||||
? Lists.newArrayList(input.getOutputTupleDesc().getId()) :
|
||||
input.getTupleIds()),
|
||||
"ANALYTIC", StatisticalType.ANALYTIC_EVAL_NODE);
|
||||
Preconditions.checkState(!tupleIds.contains(outputTupleDesc.getId()));
|
||||
// we're materializing the input row augmented with the analytic output tuple
|
||||
tupleIds.add(outputTupleDesc.getId());
|
||||
|
||||
Reference in New Issue
Block a user