[enhancement](Nereids)support count, min and avg function (#11374)

1. add count function
2. add min function
3. add avg function
This commit is contained in:
yinzhijian
2022-08-04 21:19:32 +08:00
committed by GitHub
parent 346fdeeee0
commit 6dc41d57f3
17 changed files with 491 additions and 8 deletions

View File

@ -1394,15 +1394,20 @@ public class FunctionCallExpr extends Expr {
if (fnName.getFunction().equalsIgnoreCase("sum")) {
// Prevent the cast type in vector exec engine
Type childType = getChild(0).type.getMaxResolutionType();
fn = getBuiltinFunction(fnName.getFunction(), new Type[]{childType},
fn = getBuiltinFunction(fnName.getFunction(), new Type[] {childType},
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
type = fn.getReturnType();
} else if (fnName.getFunction().equalsIgnoreCase("count")) {
fn = getBuiltinFunction(fnName.getFunction(), new Type[0], Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
type = fn.getReturnType();
} else if (fnName.getFunction().equalsIgnoreCase("substring")
|| fnName.getFunction().equalsIgnoreCase("cast")) {
Type[] childTypes = getChildren().stream().map(t -> t.type).toArray(Type[]::new);
fn = getBuiltinFunction(fnName.getFunction(), childTypes, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
type = fn.getReturnType();
} else if (fnName.getFunction().equalsIgnoreCase("year")) {
} else if (fnName.getFunction().equalsIgnoreCase("year")
|| fnName.getFunction().equalsIgnoreCase("min")
|| fnName.getFunction().equalsIgnoreCase("avg")) {
Type childType = getChild(0).type;
fn = getBuiltinFunction(fnName.getFunction(), new Type[]{childType},
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);

View File

@ -34,11 +34,13 @@ public class UnboundFunction extends Expression implements Unbound {
private final String name;
private final boolean isDistinct;
private final boolean isStar;
public UnboundFunction(String name, boolean isDistinct, List<Expression> arguments) {
public UnboundFunction(String name, boolean isDistinct, boolean isStar, List<Expression> arguments) {
super(arguments.toArray(new Expression[0]));
this.name = Objects.requireNonNull(name, "name can not be null");
this.isDistinct = isDistinct;
this.isStar = isStar;
}
public String getName() {
@ -49,6 +51,10 @@ public class UnboundFunction extends Expression implements Unbound {
return isDistinct;
}
public boolean isStar() {
return isStar;
}
public List<Expression> getArguments() {
return children();
}
@ -74,7 +80,7 @@ public class UnboundFunction extends Expression implements Unbound {
@Override
public Expression withChildren(List<Expression> children) {
return new UnboundFunction(name, isDistinct, children);
return new UnboundFunction(name, isDistinct, isStar, children);
}
@Override

View File

@ -27,6 +27,7 @@ import org.apache.doris.analysis.CastExpr;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.FloatLiteral;
import org.apache.doris.analysis.FunctionCallExpr;
import org.apache.doris.analysis.FunctionParams;
import org.apache.doris.analysis.IntLiteral;
import org.apache.doris.analysis.LikePredicate;
import org.apache.doris.analysis.NullLiteral;
@ -59,6 +60,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.Count;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import java.util.ArrayList;
@ -260,6 +262,12 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
for (Expression expr : function.getArguments()) {
paramList.add(expr.accept(this, context));
}
if (function instanceof Count) {
Count count = (Count) function;
if (count.isStar()) {
return new FunctionCallExpr(function.getName(), FunctionParams.createStarParam());
}
}
return new FunctionCallExpr(function.getName(), paramList);
}

View File

@ -123,6 +123,7 @@ import org.antlr.v4.runtime.tree.ParseTree;
import org.antlr.v4.runtime.tree.RuleNode;
import org.antlr.v4.runtime.tree.TerminalNode;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@ -453,7 +454,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
public UnboundFunction visitExtract(DorisParser.ExtractContext ctx) {
return ParserUtils.withOrigin(ctx, () -> {
String functionName = ctx.field.getText();
return new UnboundFunction(functionName, false, Arrays.asList(getExpression(ctx.source)));
return new UnboundFunction(functionName, false, false, Arrays.asList(getExpression(ctx.source)));
});
}
@ -465,7 +466,12 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
String functionName = ctx.identifier().getText();
boolean isDistinct = ctx.DISTINCT() != null;
List<Expression> params = visit(ctx.expression(), Expression.class);
return new UnboundFunction(functionName, isDistinct, params);
for (Expression expression : params) {
if (expression instanceof UnboundStar && functionName.equalsIgnoreCase("count") && !isDistinct) {
return new UnboundFunction(functionName, false, true, new ArrayList<>());
}
}
return new UnboundFunction(functionName, isDistinct, false, params);
});
}

View File

@ -24,6 +24,9 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.functions.Avg;
import org.apache.doris.nereids.trees.expressions.functions.Count;
import org.apache.doris.nereids.trees.expressions.functions.Min;
import org.apache.doris.nereids.trees.expressions.functions.Substring;
import org.apache.doris.nereids.trees.expressions.functions.Sum;
import org.apache.doris.nereids.trees.expressions.functions.Year;
@ -90,6 +93,27 @@ public class BindFunction implements AnalysisRuleFactory {
return unboundFunction;
}
return new Sum(unboundFunction.getArguments().get(0));
} else if (name.equalsIgnoreCase("count")) {
List<Expression> arguments = unboundFunction.getArguments();
if (arguments.size() > 1 || (arguments.size() == 0 && !unboundFunction.isStar())) {
return unboundFunction;
}
if (unboundFunction.isStar()) {
return new Count();
}
return new Count(unboundFunction.getArguments().get(0));
} else if (name.equalsIgnoreCase("min")) {
List<Expression> arguments = unboundFunction.getArguments();
if (arguments.size() != 1) {
return unboundFunction;
}
return new Min(unboundFunction.getArguments().get(0));
} else if (name.equalsIgnoreCase("avg")) {
List<Expression> arguments = unboundFunction.getArguments();
if (arguments.size() != 1) {
return unboundFunction;
}
return new Avg(unboundFunction.getArguments().get(0));
} else if (name.equalsIgnoreCase("substr") || name.equalsIgnoreCase("substring")) {
List<Expression> arguments = unboundFunction.getArguments();
if (arguments.size() == 2) {

View File

@ -282,6 +282,10 @@ public class BindSlotReference implements AnalysisRuleFactory {
);
}
public String toSql() {
return children.stream().map(Expression::toSql).collect(Collectors.joining(", "));
}
public List<Slot> getSlots() {
return (List) children();
}

View File

@ -0,0 +1,57 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.UnaryExpression;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.VarcharType;
import com.google.common.base.Preconditions;
import java.util.List;
/** avg agg function. */
public class Avg extends AggregateFunction implements UnaryExpression {
public Avg(Expression child) {
super("avg", child);
}
@Override
public DataType getDataType() {
return DoubleType.INSTANCE;
}
@Override
public boolean nullable() {
return child().nullable();
}
@Override
public Expression withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new Avg(children.get(0));
}
@Override
public DataType getIntermediateType() {
return VarcharType.createVarcharType(-1);
}
}

View File

@ -0,0 +1,96 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import com.google.common.base.Preconditions;
import java.util.List;
import java.util.stream.Collectors;
/** count agg function. */
public class Count extends AggregateFunction {
private final boolean isStar;
public Count() {
super("count");
this.isStar = true;
}
public Count(Expression child) {
super("count", child);
this.isStar = false;
}
public boolean isStar() {
return isStar;
}
@Override
public DataType getDataType() {
return BigIntType.INSTANCE;
}
@Override
public boolean nullable() {
return false;
}
@Override
public Expression withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 0 || children.size() == 1);
if (children.size() == 0) {
return new Count();
}
return new Count(children.get(0));
}
@Override
public DataType getIntermediateType() {
return getDataType();
}
@Override
public String toSql() throws UnboundException {
if (isStar) {
return "count(*)";
}
String args = children()
.stream()
.map(Expression::toSql)
.collect(Collectors.joining(", "));
return "count(" + args + ")";
}
@Override
public String toString() {
if (isStar) {
return "count(*)";
}
String args = children()
.stream()
.map(Expression::toString)
.collect(Collectors.joining(", "));
return "count(" + args + ")";
}
}

View File

@ -0,0 +1,55 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.UnaryExpression;
import org.apache.doris.nereids.types.DataType;
import com.google.common.base.Preconditions;
import java.util.List;
/** min agg function. */
public class Min extends AggregateFunction implements UnaryExpression {
public Min(Expression child) {
super("min", child);
}
@Override
public DataType getDataType() {
return child().getDataType();
}
@Override
public boolean nullable() {
return child().nullable();
}
@Override
public Expression withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new Min(children.get(0));
}
@Override
public DataType getIntermediateType() {
return getDataType();
}
}

View File

@ -154,8 +154,8 @@ public class ExpressionEqualsTest {
@Test
public void testUnboundFunction() {
UnboundFunction unboundFunction1 = new UnboundFunction("name", false, Lists.newArrayList(child1));
UnboundFunction unboundFunction2 = new UnboundFunction("name", false, Lists.newArrayList(child2));
UnboundFunction unboundFunction1 = new UnboundFunction("name", false, false, Lists.newArrayList(child1));
UnboundFunction unboundFunction2 = new UnboundFunction("name", false, false, Lists.newArrayList(child2));
Assertions.assertEquals(unboundFunction1, unboundFunction2);
Assertions.assertEquals(unboundFunction1.hashCode(), unboundFunction2.hashCode());
}

View File

@ -106,6 +106,12 @@ public class ExpressionParserTest {
String substring = "select substr(a, 1, 2), substring(b ,3 ,4) from test1";
assertSql(substring);
String count = "select count(*), count(b) from test1";
assertSql(count);
String min = "select min(a), min(b) as m from test1";
assertSql(min);
}
@Test

View File

@ -0,0 +1,36 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
suite("tpch_sf1_q02_nereids") {
String realDb = context.config.getDbNameByFile(context.file)
// get parent directory's group
realDb = realDb.substring(0, realDb.lastIndexOf("_"))
sql "use ${realDb}"
sql 'set enable_nereids_planner=true'
// nereids need vectorized
sql 'set enable_vectorized_engine=true'
sql 'set exec_mem_limit=2147483648*2'
test {
sql(new File(context.file.parentFile, "../sql/q02.sql").text)
resultFile(file = "../sql/q02.out", tag = "q02")
}
}

View File

@ -0,0 +1,36 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
suite("tpch_sf1_q06_nereids") {
String realDb = context.config.getDbNameByFile(context.file)
// get parent directory's group
realDb = realDb.substring(0, realDb.lastIndexOf("_"))
sql "use ${realDb}"
sql 'set enable_nereids_planner=true'
// nereids need vectorized
sql 'set enable_vectorized_engine=true'
sql 'set exec_mem_limit=2147483648*2'
test {
sql(new File(context.file.parentFile, "../sql/q06.sql").text)
resultFile(file = "../sql/q06.out", tag = "q06")
}
}

View File

@ -0,0 +1,36 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
suite("tpch_sf1_q13_nereids") {
String realDb = context.config.getDbNameByFile(context.file)
// get parent directory's group
realDb = realDb.substring(0, realDb.lastIndexOf("_"))
sql "use ${realDb}"
sql 'set enable_nereids_planner=true'
// nereids need vectorized
sql 'set enable_vectorized_engine=true'
sql 'set exec_mem_limit=2147483648*2'
test {
sql(new File(context.file.parentFile, "../sql/q13.sql").text)
resultFile(file = "../sql/q13.out", tag = "q13")
}
}

View File

@ -0,0 +1,36 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
suite("tpch_sf1_q17_nereids") {
String realDb = context.config.getDbNameByFile(context.file)
// get parent directory's group
realDb = realDb.substring(0, realDb.lastIndexOf("_"))
sql "use ${realDb}"
sql 'set enable_nereids_planner=true'
// nereids need vectorized
sql 'set enable_vectorized_engine=true'
sql 'set exec_mem_limit=2147483648*2'
test {
sql(new File(context.file.parentFile, "../sql/q17.sql").text)
resultFile(file = "../sql/q17.out", tag = "q17")
}
}

View File

@ -0,0 +1,36 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
suite("tpch_sf1_q21_nereids") {
String realDb = context.config.getDbNameByFile(context.file)
// get parent directory's group
realDb = realDb.substring(0, realDb.lastIndexOf("_"))
sql "use ${realDb}"
sql 'set enable_nereids_planner=true'
// nereids need vectorized
sql 'set enable_vectorized_engine=true'
sql 'set exec_mem_limit=2147483648*2'
test {
sql(new File(context.file.parentFile, "../sql/q21.sql").text)
resultFile(file = "../sql/q21.out", tag = "q21")
}
}

View File

@ -0,0 +1,36 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
suite("tpch_sf1_q22_nereids") {
String realDb = context.config.getDbNameByFile(context.file)
// get parent directory's group
realDb = realDb.substring(0, realDb.lastIndexOf("_"))
sql "use ${realDb}"
sql 'set enable_nereids_planner=true'
// nereids need vectorized
sql 'set enable_vectorized_engine=true'
sql 'set exec_mem_limit=2147483648*2'
test {
sql(new File(context.file.parentFile, "../sql/q22.sql").text)
resultFile(file = "../sql/q22.out", tag = "q22")
}
}