[Bug][Function] pass intermediate argument list to be (#10650)

This commit is contained in:
Pxl
2022-07-08 20:50:05 +08:00
committed by GitHub
parent 6f29a8ac0d
commit f58a071605
11 changed files with 52 additions and 43 deletions

View File

@ -491,17 +491,10 @@ public final class AggregateInfo extends AggregateInfoBase {
for (int i = 0; i < getAggregateExprs().size(); ++i) {
FunctionCallExpr inputExpr = getAggregateExprs().get(i);
Preconditions.checkState(inputExpr.isAggregateFunction());
List<Expr> paramExprs = new ArrayList<>();
// TODO(zhannngchen), change intermediate argument to a list, and remove this
// ad-hoc logic
if (inputExpr.fn.functionName().equals("max_by")
|| inputExpr.fn.functionName().equals("min_by")) {
paramExprs.addAll(inputExpr.getFnParams().exprs());
} else {
paramExprs.add(new SlotRef(inputDesc.getSlots().get(i + getGroupingExprs().size())));
}
Expr aggExprParam =
new SlotRef(inputDesc.getSlots().get(i + getGroupingExprs().size()));
FunctionCallExpr aggExpr = FunctionCallExpr.createMergeAggCall(
inputExpr, paramExprs);
inputExpr, Lists.newArrayList(aggExprParam), inputExpr.getFnParams().exprs());
aggExpr.analyzeNoThrow(analyzer);
aggExprs.add(aggExpr);
}
@ -623,7 +616,7 @@ public final class AggregateInfo extends AggregateInfoBase {
Expr aggExprParam =
new SlotRef(inputDesc.getSlots().get(i + getGroupingExprs().size()));
FunctionCallExpr aggExpr = FunctionCallExpr.createMergeAggCall(
inputExpr, Lists.newArrayList(aggExprParam));
inputExpr, Lists.newArrayList(aggExprParam), inputExpr.getFnParams().exprs());
secondPhaseAggExprs.add(aggExpr);
}
Preconditions.checkState(

View File

@ -37,7 +37,6 @@ import org.apache.doris.common.ErrorReport;
import org.apache.doris.common.util.VectorizedUtil;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.thrift.TAggregateExpr;
import org.apache.doris.thrift.TExprNode;
import org.apache.doris.thrift.TExprNodeType;
@ -69,6 +68,9 @@ public class FunctionCallExpr extends Expr {
// private BuiltinAggregateFunction.Operator aggOp;
private FunctionParams fnParams;
// represent original parament from aggregate function
private FunctionParams aggFnParams;
// check analytic function
private boolean isAnalyticFnCall = false;
// check table function
@ -92,6 +94,10 @@ public class FunctionCallExpr extends Expr {
private boolean isRewrote = false;
public void setAggFnParams(FunctionParams aggFnParams) {
this.aggFnParams = aggFnParams;
}
public void setIsAnalyticFnCall(boolean v) {
isAnalyticFnCall = v;
}
@ -153,6 +159,7 @@ public class FunctionCallExpr extends Expr {
// aggOp = e.aggOp;
isAnalyticFnCall = e.isAnalyticFnCall;
fnParams = params;
aggFnParams = e.aggFnParams;
// Just inherit the function object from 'e'.
fn = e.fn;
this.isMergeAggFn = e.isMergeAggFn;
@ -175,6 +182,7 @@ public class FunctionCallExpr extends Expr {
} else {
fnParams = new FunctionParams(other.fnParams.isDistinct(), children);
}
aggFnParams = other.aggFnParams;
this.isMergeAggFn = other.isMergeAggFn;
fn = other.fn;
this.isTableFnCall = other.isTableFnCall;
@ -428,9 +436,10 @@ public class FunctionCallExpr extends Expr {
// except in test cases that do it explicitly.
if (isAggregate() || isAnalyticFnCall) {
msg.node_type = TExprNodeType.AGG_EXPR;
if (!isAnalyticFnCall) {
msg.setAggExpr(new TAggregateExpr(isMergeAggFn));
if (aggFnParams == null) {
aggFnParams = fnParams;
}
msg.setAggExpr(aggFnParams.createTAggregateExpr(isMergeAggFn));
} else {
msg.node_type = TExprNodeType.FUNCTION_CALL;
}
@ -1143,14 +1152,15 @@ public class FunctionCallExpr extends Expr {
}
public static FunctionCallExpr createMergeAggCall(
FunctionCallExpr agg, List<Expr> params) {
FunctionCallExpr agg, List<Expr> intermediateParams, List<Expr> realParams) {
Preconditions.checkState(agg.isAnalyzed);
Preconditions.checkState(agg.isAggregateFunction());
FunctionCallExpr result = new FunctionCallExpr(
agg.fnName, new FunctionParams(false, params), true);
agg.fnName, new FunctionParams(false, intermediateParams), true);
// Inherit the function object from 'agg'.
result.fn = agg.fn;
result.type = agg.type;
result.setAggFnParams(new FunctionParams(false, realParams));
return result;
}

View File

@ -21,12 +21,15 @@
package org.apache.doris.analysis;
import org.apache.doris.common.io.Writable;
import org.apache.doris.thrift.TAggregateExpr;
import org.apache.doris.thrift.TTypeDesc;
import com.google.common.collect.Lists;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
@ -62,6 +65,18 @@ public class FunctionParams implements Writable {
return new FunctionParams();
}
public TAggregateExpr createTAggregateExpr(boolean isMergeAggFn) {
List<TTypeDesc> paramTypes = new ArrayList<TTypeDesc>();
if (exprs != null) {
for (Expr expr : exprs) {
TTypeDesc desc = expr.getType().toThrift();
desc.setIsNullable(expr.isNullable());
paramTypes.add(desc);
}
}
return new TAggregateExpr(isMergeAggFn, paramTypes);
}
public boolean isStar() {
return isStar;
}