[opt](Nereids) group_concat to support more cases (#15815)

enhance group_concat to support group_concat(cast(slot), ...) and support call it with 1 argument.
This commit is contained in:
mch_ucchi
2023-01-13 00:41:13 +08:00
committed by GitHub
parent 9d41994c17
commit a7af869bfd
4 changed files with 35 additions and 4 deletions

View File

@ -48,6 +48,10 @@ public class OrderExpression extends Expression implements UnaryExpression, Prop
return orderKey.isNullFirst();
}
public OrderKey getOrderKey() {
return orderKey;
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitOrderExpression(this, context);

View File

@ -33,6 +33,7 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.stream.Collectors;
/**
* AggregateFunction 'group_concat'. This class is generated by GenerateFunction.
@ -41,13 +42,14 @@ public class GroupConcat extends AggregateFunction
implements ExplicitlyCastableSignature, PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).varArgs(VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.varArgs(VarcharType.SYSTEM_DEFAULT, AnyDataType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.varArgs(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT, AnyDataType.INSTANCE)
);
private int nonOrderArguments;
private final int nonOrderArguments;
/**
* constructor with 1 argument.
@ -103,8 +105,10 @@ public class GroupConcat extends AggregateFunction
*/
@Override
public GroupConcat withDistinctAndChildren(boolean distinct, List<Expression> children) {
Preconditions.checkArgument(children().size() > 1);
Preconditions.checkArgument(children().size() >= 1);
// in type coercion, the orderExpression will be the child of cast, we should push down cast.
children = children.stream().map(ExpressionUtils::pushDownCastInOrderExpression).collect(Collectors.toList());
boolean foundOrderExpr = false;
int firstOrderExrIndex = 0;
for (int i = 0; i < children.size(); i++) {
@ -118,11 +122,10 @@ public class GroupConcat extends AggregateFunction
}
}
List<OrderExpression> orders = (List) children.subList(firstOrderExrIndex, children.size());
if (firstOrderExrIndex == 1) {
List<OrderExpression> orders = (List) children.subList(firstOrderExrIndex, children.size());
return new GroupConcat(distinct, children.get(0), orders.toArray(new OrderExpression[0]));
} else if (firstOrderExrIndex == 2) {
List<OrderExpression> orders = (List) children.subList(firstOrderExrIndex, children.size());
return new GroupConcat(distinct, children.get(0), children.get(1), orders.toArray(new OrderExpression[0]));
} else {
throw new AnalysisException("group_concat requires one or two parameters: " + children);

View File

@ -28,6 +28,7 @@ import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
@ -466,4 +467,17 @@ public class ExpressionUtils {
.flatMap(expr -> expr.getInputSlots().stream())
.collect(ImmutableSet.toImmutableSet());
}
/**
* cast push down in order expression
*/
public static Expression pushDownCastInOrderExpression(Expression expression) {
if (expression instanceof Cast
&& ((Cast) expression).child() instanceof OrderExpression) {
Cast cast = ((Cast) expression);
OrderExpression order = ((OrderExpression) cast.child());
return order.withChildren(cast.withChildren(order.getOrderKey().getExpr()));
}
return expression;
}
}