[feature-wip](nereids) Made decimal in nereids more complete (#15087)

1. Add IntegralDivide operator to support `DIV` semantics
2. Add more operator rewriter to keep expression type consistent between operators
3. Support the convertion between float type and decimal type.

After this PR, below cases could be executed normaly like the legacy optimizer:
  use test_query_db;
  select k1, k5,100000*k5 from test order by k1, k2, k3, k4;
  select avg(k9) as a from test group by k1 having a < 100.0 order by a;
This commit is contained in:
Kikyou1997
2022-12-29 13:01:47 +08:00
committed by GitHub
parent 29492f0d6c
commit 5b09d27d54
18 changed files with 642 additions and 12 deletions

View File

@ -31,6 +31,9 @@ import org.apache.doris.nereids.rules.analysis.ReplaceExpressionByChildOutput;
import org.apache.doris.nereids.rules.analysis.ResolveOrdinalInOrderByAndGroupBy;
import org.apache.doris.nereids.rules.analysis.Scope;
import org.apache.doris.nereids.rules.analysis.UserAuthentication;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization;
import org.apache.doris.nereids.rules.expression.rewrite.rules.CharacterLiteralTypeCoercion;
import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
import org.apache.doris.nereids.rules.rewrite.logical.HideOneRowRelationUnderUnion;
import com.google.common.collect.ImmutableList;
@ -68,7 +71,9 @@ public class AnalyzeRulesJob extends BatchRulesJob {
new ProjectWithDistinctToAggregate(),
new ResolveOrdinalInOrderByAndGroupBy(),
new ReplaceExpressionByChildOutput(),
new HideOneRowRelationUnderUnion()
new HideOneRowRelationUnderUnion(),
new ExpressionNormalization(cascadesContext.getConnectContext(),
ImmutableList.of(CharacterLiteralTypeCoercion.INSTANCE, TypeCoercion.INSTANCE))
)),
topDownBatch(ImmutableList.of(
new FillUpMissingSlots(),

View File

@ -122,6 +122,7 @@ import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.InSubquery;
import org.apache.doris.nereids.trees.expressions.IntegralDivide;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
@ -672,7 +673,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
case DorisParser.MINUS:
return new Subtract(left, right);
case DorisParser.DIV:
return new Divide(left, right);
return new IntegralDivide(left, right);
case DorisParser.HAT:
return new BitXor(left, right);
case DorisParser.PIPE:

View File

@ -135,6 +135,9 @@ public enum RuleType {
REWRITE_FILTER_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_JOIN_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_GENERATE_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_SORT_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_HAVING_EXPRESSSION(RuleTypeClass.REWRITE),
REWRITE_REPEAT_EXPRESSSION(RuleTypeClass.REWRITE),
REORDER_JOIN(RuleTypeClass.REWRITE),
// Merge Consecutive plan
MERGE_PROJECTS(RuleTypeClass.REWRITE),

View File

@ -66,6 +66,7 @@ public class ResolveOrdinalInOrderByAndGroupBy implements AnalysisRuleFactory {
logicalAggregate().then(agg -> {
List<NamedExpression> aggOutput = agg.getOutputExpressions();
List<Expression> groupByWithoutOrd = new ArrayList<>();
boolean ordExists = false;
for (Expression groupByExpr : agg.getGroupByExpressions()) {
groupByExpr = FoldConstantRule.INSTANCE.rewrite(groupByExpr);
if (groupByExpr instanceof IntegerLikeLiteral) {
@ -74,11 +75,17 @@ public class ResolveOrdinalInOrderByAndGroupBy implements AnalysisRuleFactory {
checkOrd(ord, aggOutput.size());
Expression aggExpr = aggOutput.get(ord - 1);
groupByWithoutOrd.add(aggExpr);
ordExists = true;
} else {
groupByWithoutOrd.add(groupByExpr);
}
}
return new LogicalAggregate(groupByWithoutOrd, agg.getOutputExpressions(), agg.child());
if (ordExists) {
return new LogicalAggregate(groupByWithoutOrd, agg.getOutputExpressions(), agg.child());
} else {
return agg;
}
}))).build();
}

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.rules.expression.rewrite;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
@ -35,6 +36,8 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
@ -66,7 +69,10 @@ public class ExpressionRewrite implements RewriteRuleFactory {
new ProjectExpressionRewrite().build(),
new AggExpressionRewrite().build(),
new FilterExpressionRewrite().build(),
new JoinExpressionRewrite().build());
new JoinExpressionRewrite().build(),
new SortExpressionRewrite().build(),
new LogicalRepeatRewrite().build(),
new HavingExpressionRewrite().build());
}
private class GenerateExpressionRewrite extends OneRewriteRuleFactory {
@ -183,4 +189,48 @@ public class ExpressionRewrite implements RewriteRuleFactory {
}).toRule(RuleType.REWRITE_JOIN_EXPRESSION);
}
}
private class SortExpressionRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalSort().then(sort -> {
List<OrderKey> orderKeys = sort.getOrderKeys();
List<OrderKey> rewrittenOrderKeys = new ArrayList<>();
for (OrderKey k : orderKeys) {
Expression expression = rewriter.rewrite(k.getExpr());
rewrittenOrderKeys.add(new OrderKey(expression, k.isAsc(), k.isNullFirst()));
}
return sort.withOrderByKey(rewrittenOrderKeys);
}).toRule(RuleType.REWRITE_SORT_EXPRESSION);
}
}
private class HavingExpressionRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalHaving().then(having -> {
Set<Expression> rewrittenExpr = new HashSet<>();
for (Expression e : having.getExpressions()) {
rewrittenExpr.add(rewriter.rewrite(e));
}
return having.withExpressions(rewrittenExpr);
}).toRule(RuleType.REWRITE_HAVING_EXPRESSSION);
}
}
private class LogicalRepeatRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalRepeat().then(r -> {
List<List<Expression>> groupingExprs = new ArrayList<>();
for (List<Expression> expressions : r.getGroupingSets()) {
groupingExprs.add(expressions.stream().map(rewriter::rewrite).collect(Collectors.toList()));
}
return r.withGroupSetsAndOutput(groupingExprs,
r.getOutputExpressions().stream().map(rewriter::rewrite).map(e -> (NamedExpression) e)
.collect(Collectors.toList()));
}).toRule(RuleType.REWRITE_REPEAT_EXPRESSSION);
}
}
}

View File

@ -29,6 +29,7 @@ import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IntegralDivide;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
@ -218,4 +219,12 @@ public class TypeCoercion extends AbstractExpressionRewriteRule {
}
});
}
@Override
public Expression visitIntegralDivide(IntegralDivide integralDivide, ExpressionRewriteContext context) {
DataType commonType = BigIntType.INSTANCE;
Expression newLeft = TypeCoercionUtils.castIfNotSameType(integralDivide.left(), commonType);
Expression newRight = TypeCoercionUtils.castIfNotSameType(integralDivide.right(), commonType);
return integralDivide.withChildren(newLeft, newRight);
}
}

View File

@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IntegralDivide;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.Subtract;
@ -150,7 +151,7 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStatistic, Sta
.setNumNulls(numNulls).setDataSize(dataSize).setMinValue(min).setMaxValue(max).setSelectivity(1.0)
.setMaxExpr(null).setMinExpr(null).build();
}
if (binaryArithmetic instanceof Divide) {
if (binaryArithmetic instanceof Divide || binaryArithmetic instanceof IntegralDivide) {
double min = Math.min(
Math.min(
Math.min(leftMin / noneZeroDivisor(rightMin), leftMin / noneZeroDivisor(rightMax)),

View File

@ -0,0 +1,60 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import org.apache.doris.nereids.types.coercion.NumericType;
import com.google.common.base.Preconditions;
import java.util.List;
/**
* A DIV B
*/
public class IntegralDivide extends BinaryArithmetic {
public IntegralDivide(Expression left, Expression right) {
super(left, right, Operator.INT_DIVIDE);
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitIntegralDivide(this, context);
}
@Override
public AbstractDataType inputType() {
return NumericType.INSTANCE;
}
// Divide is implemented as a scalar function which return type is always nullable.
@Override
public boolean nullable() throws UnboundException {
return true;
}
@Override
public Expression withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 2);
return new IntegralDivide(children.get(0), children.get(1));
}
}

View File

@ -46,6 +46,7 @@ import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.InSubquery;
import org.apache.doris.nereids.trees.expressions.IntegralDivide;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
@ -391,6 +392,10 @@ public abstract class ExpressionVisitor<R, C>
return visit(boundStar, context);
}
public R visitIntegralDivide(IntegralDivide integralDivide, C context) {
return visitBinaryArithmetic(integralDivide, context);
}
/* ********************************************************************************************
* Unbound expressions
* ********************************************************************************************/

View File

@ -84,6 +84,11 @@ public class LogicalHaving<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_T
return new LogicalHaving<>(conjuncts, Optional.empty(), logicalProperties, child());
}
public Plan withExpressions(Set<Expression> expressions) {
return new LogicalHaving<Plan>(expressions, Optional.empty(),
Optional.of(getLogicalProperties()), child());
}
@Override
public List<Slot> computeOutput() {
return child().getOutput();

View File

@ -156,6 +156,14 @@ public class TypeCoercionUtils {
|| leftType instanceof IntegralType && rightType instanceof DecimalV2Type) {
return true;
}
if (leftType instanceof FloatType && rightType instanceof DecimalV2Type
|| leftType instanceof DecimalV2Type && rightType instanceof FloatType) {
return true;
}
if (leftType instanceof DoubleType && rightType instanceof DecimalV2Type
|| leftType instanceof DecimalV2Type && rightType instanceof DoubleType) {
return true;
}
// TODO: add decimal promotion support
if (!(leftType instanceof DecimalV2Type)
&& !(rightType instanceof DecimalV2Type)
@ -230,7 +238,7 @@ public class TypeCoercionUtils {
|| (right instanceof DateLikeType && left instanceof IntegralType)) {
tightestCommonType = BigIntType.INSTANCE;
}
return Optional.ofNullable(tightestCommonType);
return tightestCommonType == null ? Optional.of(DoubleType.INSTANCE) : Optional.of(tightestCommonType);
}
/**

View File

@ -32,6 +32,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
@ -312,14 +313,16 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Patt
ExceptionChecker.expectThrowsWithMsg(
AnalysisException.class,
"Aggregate functions in having clause can't be nested: sum((a1 + avg(a2))).",
"Aggregate functions in having clause can't be nested:"
+ " sum(cast((cast(a1 as DOUBLE) + avg(cast(a2 as DOUBLE))) as SMALLINT)).",
() -> PlanChecker.from(connectContext).analyze(
"SELECT a1 FROM t1 GROUP BY a1 HAVING SUM(a1 + AVG(a2)) > 0"
));
ExceptionChecker.expectThrowsWithMsg(
AnalysisException.class,
"Aggregate functions in having clause can't be nested: sum(((a1 + a2) + avg(a2))).",
"Aggregate functions in having clause can't be nested:"
+ " sum(cast((cast((a1 + a2) as DOUBLE) + avg(cast(a2 as DOUBLE))) as INT)).",
() -> PlanChecker.from(connectContext).analyze(
"SELECT a1 FROM t1 GROUP BY a1 HAVING SUM(a1 + a2 + AVG(a2)) > 0"
));
@ -530,8 +533,8 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Patt
ImmutableList.of(
new OrderKey(pk, true, true),
new OrderKey(countA11.toSlot(), true, true),
new OrderKey(new Add(sumA1A2.toSlot(), new TinyIntLiteral((byte) 1)), true, true),
new OrderKey(new Add(v1.toSlot(), new TinyIntLiteral((byte) 1)), true, true),
new OrderKey(new Add(sumA1A2.toSlot(), new BigIntLiteral((byte) 1)), true, true),
new OrderKey(new Add(v1.toSlot(), new BigIntLiteral((byte) 1)), true, true),
new OrderKey(v1.toSlot(), true, true)
)
))

View File

@ -145,7 +145,7 @@ public class TypeCoercionUtilsTest {
testFindTightestCommonType(BigIntType.INSTANCE, IntegerType.INSTANCE, BigIntType.INSTANCE);
testFindTightestCommonType(StringType.INSTANCE, StringType.INSTANCE, IntegerType.INSTANCE);
testFindTightestCommonType(StringType.INSTANCE, IntegerType.INSTANCE, StringType.INSTANCE);
testFindTightestCommonType(null, DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.createDecimalV2Type(2, 1));
testFindTightestCommonType(DoubleType.INSTANCE, DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.createDecimalV2Type(2, 1));
testFindTightestCommonType(VarcharType.createVarcharType(10), CharType.createCharType(8), CharType.createCharType(10));
testFindTightestCommonType(VarcharType.createVarcharType(10), VarcharType.createVarcharType(8), VarcharType.createVarcharType(10));
testFindTightestCommonType(VarcharType.createVarcharType(10), VarcharType.createVarcharType(8), CharType.createCharType(10));