[fix](nereids)enable decimalv3 by default for nereids (#19906)

This commit is contained in:
starocean999
2023-05-24 13:36:24 +08:00
committed by GitHub
parent f14e6189a9
commit 70f2e8ff80
21 changed files with 157 additions and 67 deletions

View File

@ -30,6 +30,7 @@ import com.google.gson.annotations.SerializedName;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
@ -1111,6 +1112,12 @@ public class ScalarType extends Type {
return getAssignmentCompatibleDecimalV2Type(t1, t2);
}
if ((t1.isDecimalV3() && t2.isDecimalV2()) || (t2.isDecimalV3() && t1.isDecimalV2())) {
int scale = Math.max(t1.scale, t2.scale);
int integerPart = Math.max(t1.precision - t1.scale, t2.precision - t2.scale);
return ScalarType.createDecimalV3Type(integerPart + scale, scale);
}
if (t1.isDecimalV2() || t2.isDecimalV2()) {
if (t1.isFloatingPointType() || t2.isFloatingPointType()) {
return MAX_DECIMALV2_TYPE;
@ -1118,8 +1125,42 @@ public class ScalarType extends Type {
return t1.isDecimalV2() ? t1 : t2;
}
if ((t1.isDecimalV3() && t2.isFixedPointType()) || (t2.isDecimalV3() && t1.isFixedPointType())) {
return t1.isDecimalV3() ? t1 : t2;
if (t1.isDecimalV3() || t2.isDecimalV3()) {
if (t1.isFloatingPointType() || t2.isFloatingPointType()) {
return t1.isFloatingPointType() ? t1 : t2;
} else if (t1.isBoolean() || t2.isBoolean()) {
return t1.isDecimalV3() ? t1 : t2;
}
}
if ((t1.isDecimalV3() && t2.isFixedPointType())
|| (t2.isDecimalV3() && t1.isFixedPointType())) {
int precision;
int scale;
ScalarType intType;
if (t1.isDecimalV3()) {
precision = t1.precision;
scale = t1.scale;
intType = t2;
} else {
precision = t2.precision;
scale = t2.scale;
intType = t1;
}
int integerPart = precision - scale;
if (intType.isScalarType(PrimitiveType.TINYINT)
|| intType.isScalarType(PrimitiveType.SMALLINT)) {
integerPart = Math.max(integerPart, new BigDecimal(Short.MAX_VALUE).precision());
} else if (intType.isScalarType(PrimitiveType.INT)) {
integerPart = Math.max(integerPart, new BigDecimal(Integer.MAX_VALUE).precision());
} else {
integerPart = ScalarType.MAX_DECIMAL128_PRECISION - scale;
}
if (scale + integerPart <= ScalarType.MAX_DECIMAL128_PRECISION) {
return ScalarType.createDecimalV3Type(scale + integerPart, scale);
} else {
return Type.DOUBLE;
}
}
if (t1.isDecimalV3() && t2.isDecimalV3()) {

View File

@ -1813,8 +1813,12 @@ public abstract class Type {
} else {
resultDecimalType = PrimitiveType.DECIMAL128;
}
return ScalarType.createDecimalType(resultDecimalType, resultPrecision,
Math.max(((ScalarType) t1).getScalarScale(), ((ScalarType) t2).getScalarScale()));
if (resultPrecision <= ScalarType.MAX_DECIMAL128_PRECISION) {
return ScalarType.createDecimalType(resultDecimalType, resultPrecision, Math.max(
((ScalarType) t1).getScalarScale(), ((ScalarType) t2).getScalarScale()));
} else {
return Type.DOUBLE;
}
}
if (t1ResultType.isDecimalV3Type() || t2ResultType.isDecimalV3Type()) {
return getAssignmentCompatibleType(t1, t2, false);

View File

@ -414,9 +414,13 @@ public class BinaryPredicate extends Predicate implements Writable {
if (t1 == PrimitiveType.BIGINT && t2 == PrimitiveType.BIGINT) {
return Type.getAssignmentCompatibleType(getChild(0).getType(), getChild(1).getType(), false);
}
if ((t1 == PrimitiveType.BIGINT || t1 == PrimitiveType.DECIMALV2)
&& (t2 == PrimitiveType.BIGINT || t2 == PrimitiveType.DECIMALV2)) {
return Type.DECIMALV2;
if ((t1 == PrimitiveType.BIGINT && t2 == PrimitiveType.DECIMALV2)
|| (t2 == PrimitiveType.BIGINT && t1 == PrimitiveType.DECIMALV2)
|| (t1 == PrimitiveType.LARGEINT && t2 == PrimitiveType.DECIMALV2)
|| (t2 == PrimitiveType.LARGEINT && t1 == PrimitiveType.DECIMALV2)) {
// only decimalv3 can hold big and large int
return ScalarType.createDecimalType(PrimitiveType.DECIMAL128, ScalarType.MAX_DECIMAL128_PRECISION,
ScalarType.MAX_DECIMALV2_SCALE);
}
if ((t1 == PrimitiveType.BIGINT || t1 == PrimitiveType.LARGEINT)
&& (t2 == PrimitiveType.BIGINT || t2 == PrimitiveType.LARGEINT)) {
@ -603,9 +607,9 @@ public class BinaryPredicate extends Predicate implements Writable {
}
public Range<LiteralExpr> convertToRange() {
Preconditions.checkState(getChild(0) instanceof SlotRef);
Preconditions.checkState(getChild(1) instanceof LiteralExpr);
LiteralExpr literalExpr = (LiteralExpr) getChild(1);
Preconditions.checkState(getChildWithoutCast(0) instanceof SlotRef);
Preconditions.checkState(getChildWithoutCast(1) instanceof LiteralExpr);
LiteralExpr literalExpr = (LiteralExpr) getChildWithoutCast(1);
switch (op) {
case EQ:
return Range.singleton(literalExpr);

View File

@ -286,8 +286,10 @@ public class DecimalLiteral extends LiteralExpr {
@Override
protected void compactForLiteral(Type type) throws AnalysisException {
if (type.isDecimalV3()) {
this.type = ScalarType.createDecimalV3Type(Math.max(this.value.precision(), type.getPrecision()),
Math.max(this.value.scale(), ((ScalarType) type).decimalScale()));
int scale = Math.max(this.value.scale(), ((ScalarType) type).decimalScale());
int integerPart = Math.max(this.value.precision() - this.value.scale(),
type.getPrecision() - ((ScalarType) type).decimalScale());
this.type = ScalarType.createDecimalV3Type(integerPart + scale, scale);
}
}

View File

@ -495,6 +495,12 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
return result;
}
public Expr getChildWithoutCast(int i) {
Preconditions.checkArgument(i < children.size(), "child index {0} out of range {1}", i, children.size());
Expr child = children.get(i);
return child instanceof CastExpr ? child.children.get(0) : child;
}
/**
* Helper function: analyze list of exprs
*

View File

@ -1404,23 +1404,24 @@ public class FunctionCallExpr extends Expr {
fn = getBuiltinFunction(fnName.getFunction(), childTypes,
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
} else if ((fnName.getFunction().equalsIgnoreCase("coalesce")
|| fnName.getFunction().equalsIgnoreCase("greatest")
|| fnName.getFunction().equalsIgnoreCase("least")) && children.size() > 1) {
|| fnName.getFunction().equalsIgnoreCase("least")
|| fnName.getFunction().equalsIgnoreCase("greatest")) && children.size() > 1) {
Type[] childTypes = collectChildReturnTypes();
Type assignmentCompatibleType = childTypes[0];
for (int i = 1; i < childTypes.length && assignmentCompatibleType.isDecimalV3(); i++) {
assignmentCompatibleType =
ScalarType.getAssignmentCompatibleType(assignmentCompatibleType, childTypes[i], true);
for (int i = 1; i < childTypes.length; i++) {
assignmentCompatibleType = ScalarType
.getAssignmentCompatibleType(assignmentCompatibleType, childTypes[i], true);
}
if (assignmentCompatibleType.isDecimalV3()) {
for (int i = 0; i < childTypes.length; i++) {
if (assignmentCompatibleType.isDecimalV3() && !childTypes[i].equals(assignmentCompatibleType)) {
if (assignmentCompatibleType.isDecimalV3()
&& !childTypes[i].equals(assignmentCompatibleType)) {
uncheckedCastChild(assignmentCompatibleType, i);
argTypes[i] = assignmentCompatibleType;
}
}
}
fn = getBuiltinFunction(fnName.getFunction(), childTypes,
fn = getBuiltinFunction(fnName.getFunction(), argTypes,
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
} else if (AggregateFunction.SUPPORT_ORDER_BY_AGGREGATE_FUNCTION_NAME_SET.contains(
fnName.getFunction().toLowerCase())) {
@ -1658,6 +1659,7 @@ public class FunctionCallExpr extends Expr {
} else if (!argTypes[i].matchesType(args[ix])
&& (!fn.getReturnType().isDecimalV3OrContainsDecimalV3()
|| (argTypes[i].isValid() && !argTypes[i].isDecimalV3() && args[ix].isDecimalV3()))) {
// || (argTypes[i].isValid() && argTypes[i].getPrimitiveType() != args[ix].getPrimitiveType()))) {
uncheckedCastChild(args[ix], i);
}
}

View File

@ -63,15 +63,19 @@ public class SimplifyDecimalV3Comparison extends AbstractExpressionRewriteRule {
private Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) {
BigDecimal trailingZerosValue = right.getValue().stripTrailingZeros();
int scale = org.apache.doris.analysis.DecimalLiteral.getBigDecimalScale(trailingZerosValue);
int precision = org.apache.doris.analysis.DecimalLiteral.getBigDecimalScale(trailingZerosValue);
int precision = org.apache.doris.analysis.DecimalLiteral.getBigDecimalPrecision(trailingZerosValue);
Expression castChild = left.child();
Preconditions.checkState(castChild.getDataType() instanceof DecimalV3Type);
DecimalV3Type leftType = (DecimalV3Type) castChild.getDataType();
// precision and scale of literal must all smaller than left, otherwise we need to do cast on right.
Preconditions.checkState(scale <= leftType.getScale(), "right scale should not greater than left");
Preconditions.checkState(precision <= leftType.getPrecision(), "right precision should not greater than left");
DecimalV3Literal newRight = new DecimalV3Literal(
DecimalV3Type.createDecimalV3Type(leftType.getPrecision(), leftType.getScale()), trailingZerosValue);
return cp.withChildren(castChild, newRight);
if (scale <= leftType.getScale() && precision - scale <= leftType.getPrecision() - leftType.getScale()) {
// precision and scale of literal all smaller than left, we don't need the cast
DecimalV3Literal newRight = new DecimalV3Literal(
DecimalV3Type.createDecimalV3Type(leftType.getPrecision(), leftType.getScale()),
trailingZerosValue);
return cp.withChildren(castChild, newRight);
} else {
return cp;
}
}
}

View File

@ -21,6 +21,7 @@ import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.catalog.FunctionSignature.TripleFunction;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeV2Type;
@ -32,6 +33,7 @@ import org.apache.doris.nereids.util.ResponsibilityChain;
import com.google.common.base.Preconditions;
import java.math.BigDecimal;
import java.util.List;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
@ -160,8 +162,15 @@ public class ComputeSignatureHelper {
if (finalType == null) {
finalType = DecimalV3Type.forType(arguments.get(i).getDataType());
} else {
finalType = DecimalV3Type.widerDecimalV3Type((DecimalV3Type) finalType,
DecimalV3Type.forType(arguments.get(i).getDataType()), true);
Expression arg = arguments.get(i);
DecimalV3Type argType;
if (arg.isLiteral() && arg.getDataType().isIntegralType()) {
// create decimalV3 with minimum scale enough to hold the integral literal
argType = DecimalV3Type.createDecimalV3Type(new BigDecimal(((Literal) arg).getStringValue()));
} else {
argType = DecimalV3Type.forType(arg.getDataType());
}
finalType = DecimalV3Type.widerDecimalV3Type((DecimalV3Type) finalType, argType, true);
}
Preconditions.checkState(finalType.isDecimalV3Type(),
"decimalv3 precision promotion failed.");

View File

@ -28,6 +28,7 @@ import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.collect.Lists;
import java.math.BigDecimal;
import java.util.List;
import java.util.Optional;
import java.util.function.BiFunction;
@ -141,8 +142,14 @@ public class SearchSignature {
if (finalType == null) {
finalType = DecimalV3Type.forType(arguments.get(i).getDataType());
} else {
finalType = DecimalV3Type.widerDecimalV3Type((DecimalV3Type) finalType,
DecimalV3Type.forType(arguments.get(i).getDataType()), true);
Expression arg = arguments.get(i);
if (arg.isLiteral() && arg.getDataType().isIntegralType()) {
// create decimalV3 with minimum scale enough to hold the integral literal
finalType = DecimalV3Type.createDecimalV3Type(new BigDecimal(((Literal) arg).getStringValue()));
} else {
finalType = DecimalV3Type.widerDecimalV3Type((DecimalV3Type) finalType,
DecimalV3Type.forType(arg.getDataType()), true);
}
}
if (!finalType.isDecimalV3Type()) {
return false;

View File

@ -25,7 +25,6 @@ import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.IntegerType;
@ -44,14 +43,13 @@ public class AvgWeighted extends AggregateFunction
implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DecimalV2Type.SYSTEM_DEFAULT, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DecimalV3Type.WILDCARD, DoubleType.INSTANCE)
FunctionSignature.ret(DoubleType.INSTANCE).args(DecimalV2Type.SYSTEM_DEFAULT, DoubleType.INSTANCE)
);
/**

View File

@ -31,6 +31,7 @@ import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.IntegerType;
@ -60,7 +61,8 @@ public class Histogram extends AggregateFunction
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(LargeIntType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(FloatType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DoubleType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DecimalV2Type.CATALOG_DEFAULT),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DecimalV3Type.WILDCARD),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DateType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DateTimeType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DateV2Type.INSTANCE),

View File

@ -44,12 +44,12 @@ public class Stddev extends NullableAggregateFunction
StdDevOrVarianceFunction, DecimalStddevPrecision {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT)
);

View File

@ -45,12 +45,12 @@ public class StddevSamp extends AggregateFunction
StdDevOrVarianceFunction, DecimalStddevPrecision {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT)
);

View File

@ -44,12 +44,12 @@ public class Variance extends NullableAggregateFunction
StdDevOrVarianceFunction, DecimalStddevPrecision {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT)
);

View File

@ -44,12 +44,12 @@ public class VarianceSamp extends AggregateFunction
StdDevOrVarianceFunction, AlwaysNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT)
);

View File

@ -64,8 +64,8 @@ public class Coalesce extends ScalarFunction
FunctionSignature.ret(DateType.INSTANCE).varArgs(DateType.INSTANCE),
FunctionSignature.ret(DateTimeV2Type.SYSTEM_DEFAULT).varArgs(DateTimeV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(DateV2Type.INSTANCE).varArgs(DateV2Type.INSTANCE),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).varArgs(DecimalV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(DecimalV3Type.WILDCARD).varArgs(DecimalV3Type.WILDCARD),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).varArgs(DecimalV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(BitmapType.INSTANCE).varArgs(BitmapType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).varArgs(VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(StringType.INSTANCE).varArgs(StringType.INSTANCE)

View File

@ -82,10 +82,10 @@ public class If extends ScalarFunction
FunctionSignature.ret(DateTimeType.INSTANCE)
.args(BooleanType.INSTANCE, DateTimeType.INSTANCE, DateTimeType.INSTANCE),
FunctionSignature.ret(DateType.INSTANCE).args(BooleanType.INSTANCE, DateType.INSTANCE, DateType.INSTANCE),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT)
.args(BooleanType.INSTANCE, DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(DecimalV3Type.WILDCARD)
.args(BooleanType.INSTANCE, DecimalV3Type.WILDCARD, DecimalV3Type.WILDCARD),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT)
.args(BooleanType.INSTANCE, DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(BitmapType.INSTANCE)
.args(BooleanType.INSTANCE, BitmapType.INSTANCE, BitmapType.INSTANCE),
FunctionSignature.ret(HllType.INSTANCE).args(BooleanType.INSTANCE, HllType.INSTANCE, HllType.INSTANCE),
@ -125,14 +125,14 @@ public class If extends ScalarFunction
ArrayType.of(DateTimeV2Type.SYSTEM_DEFAULT)),
FunctionSignature.ret(ArrayType.of(DateV2Type.INSTANCE))
.args(BooleanType.INSTANCE, ArrayType.of(DateV2Type.INSTANCE), ArrayType.of(DateV2Type.INSTANCE)),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.args(BooleanType.INSTANCE,
ArrayType.of(DecimalV2Type.SYSTEM_DEFAULT),
ArrayType.of(DecimalV2Type.SYSTEM_DEFAULT)),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.args(BooleanType.INSTANCE,
ArrayType.of(DecimalV3Type.WILDCARD),
ArrayType.of(DecimalV3Type.WILDCARD)),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.args(BooleanType.INSTANCE,
ArrayType.of(DecimalV2Type.SYSTEM_DEFAULT),
ArrayType.of(DecimalV2Type.SYSTEM_DEFAULT)),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.args(BooleanType.INSTANCE,
ArrayType.of(VarcharType.SYSTEM_DEFAULT),

View File

@ -68,10 +68,10 @@ public class Nvl extends ScalarFunction
.args(DateTimeV2Type.SYSTEM_DEFAULT, DateTimeV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(DateV2Type.INSTANCE)
.args(DateV2Type.INSTANCE, DateV2Type.INSTANCE),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT)
.args(DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(DecimalV3Type.WILDCARD)
.args(DecimalV3Type.WILDCARD, DecimalV3Type.WILDCARD),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT)
.args(DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(BitmapType.INSTANCE).args(BitmapType.INSTANCE, BitmapType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT),

View File

@ -37,8 +37,9 @@ public class DecimalLiteral extends Literal {
}
public DecimalLiteral(DecimalV2Type dataType, BigDecimal value) {
super(DecimalV2Type.createDecimalV2Type(dataType.getPrecision(), dataType.getScale()));
this.value = Objects.requireNonNull(value.setScale(dataType.getScale(), RoundingMode.DOWN));
super(dataType);
BigDecimal adjustedValue = value.scale() < 0 ? value : value.setScale(dataType.getScale(), RoundingMode.DOWN);
this.value = Objects.requireNonNull(adjustedValue);
}
@Override
@ -53,7 +54,7 @@ public class DecimalLiteral extends Literal {
@Override
public LiteralExpr toLegacyLiteral() {
return new org.apache.doris.analysis.DecimalLiteral(value);
return new org.apache.doris.analysis.DecimalLiteral(value, dataType.toCatalogDataType());
}
@Override

View File

@ -451,8 +451,8 @@ public class TypeCoercionUtils {
}
DataType commonType = DoubleType.INSTANCE;
if (t1.isDoubleType() || t1.isFloatType() || t1.isLargeIntType()
|| t2.isDoubleType() || t2.isFloatType() || t2.isLargeIntType()) {
if (t1.isDoubleType() || t1.isFloatType()
|| t2.isDoubleType() || t2.isFloatType()) {
// double type
} else if (t1.isDecimalV3Type() || t2.isDecimalV3Type()) {
// divide should cast to precision and target scale
@ -535,6 +535,9 @@ public class TypeCoercionUtils {
break;
}
}
if (commonType.isFloatType() && (t1.isDecimalV3Type() || t2.isDecimalV3Type())) {
commonType = DoubleType.INSTANCE;
}
boolean isBitArithmetic = binaryArithmetic instanceof BitAnd
|| binaryArithmetic instanceof BitOr
@ -577,13 +580,12 @@ public class TypeCoercionUtils {
return castChildren(binaryArithmetic, left, right, DoubleType.INSTANCE);
}
// add, subtract and mod should cast children to exactly same type as return type
// add, subtract should cast children to exactly same type as return type
if (binaryArithmetic instanceof Add
|| binaryArithmetic instanceof Subtract
|| binaryArithmetic instanceof Mod) {
|| binaryArithmetic instanceof Subtract) {
return castChildren(binaryArithmetic, left, right, retType);
}
// multiply do not need to cast children to same type
// multiply and mode do not need to cast children to same type
return binaryArithmetic.withChildren(castIfNotSameType(left, dt1), castIfNotSameType(right, dt2));
}
@ -959,8 +961,16 @@ public class TypeCoercionUtils {
DecimalV3Type.forType(leftType), DecimalV3Type.forType(rightType), true));
}
if (leftType instanceof DecimalV2Type || rightType instanceof DecimalV2Type) {
return Optional.of(DecimalV2Type.widerDecimalV2Type(
if (leftType instanceof BigIntType || rightType instanceof BigIntType
|| leftType instanceof LargeIntType || rightType instanceof LargeIntType) {
// only decimalv3 can hold big or large int
return Optional
.of(DecimalV3Type.widerDecimalV3Type(DecimalV3Type.forType(leftType),
DecimalV3Type.forType(rightType), true));
} else {
return Optional.of(DecimalV2Type.widerDecimalV2Type(
DecimalV2Type.forType(leftType), DecimalV2Type.forType(rightType)));
}
}
return Optional.of(commonType);
}

View File

@ -229,7 +229,7 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
if (!singleColumnPredicate(predicate)) {
continue;
}
SlotRef columnName = (SlotRef) predicate.getChild(0);
SlotRef columnName = (SlotRef) predicate.getChildWithoutCast(0);
if (predicate instanceof BinaryPredicate) {
Range<LiteralExpr> predicateRange = ((BinaryPredicate) predicate).convertToRange();
if (predicateRange == null) {
@ -319,14 +319,14 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
if (inPredicate.isNotIn()) {
return false;
}
if (inPredicate.getChild(0) instanceof SlotRef) {
if (inPredicate.getChildWithoutCast(0) instanceof SlotRef) {
return true;
}
return false;
} else if (expr instanceof BinaryPredicate) {
BinaryPredicate binaryPredicate = (BinaryPredicate) expr;
if (binaryPredicate.getChild(0) instanceof SlotRef
&& binaryPredicate.getChild(1) instanceof LiteralExpr) {
if (binaryPredicate.getChildWithoutCast(0) instanceof SlotRef
&& binaryPredicate.getChildWithoutCast(1) instanceof LiteralExpr) {
return true;
}
return false;
@ -518,9 +518,9 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
notMergedExprs.add(new CompoundPredicate(Operator.AND, left, right));
} else if (!(predicate instanceof BinaryPredicate) && !(predicate instanceof InPredicate)) {
notMergedExprs.add(predicate);
} else if (!(predicate.getChild(0) instanceof SlotRef)) {
} else if (!(predicate.getChildWithoutCast(0) instanceof SlotRef)) {
notMergedExprs.add(predicate);
} else if (!(predicate.getChild(1) instanceof LiteralExpr)) {
} else if (!(predicate.getChildWithoutCast(1) instanceof LiteralExpr)) {
notMergedExprs.add(predicate);
} else if (predicate instanceof BinaryPredicate
&& ((BinaryPredicate) predicate).getOp() != BinaryPredicate.Operator.EQ) {
@ -529,13 +529,13 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
&& ((InPredicate) predicate).isNotIn()) {
notMergedExprs.add(predicate);
} else {
TableName tableName = ((SlotRef) predicate.getChild(0)).getTableName();
TableName tableName = ((SlotRef) predicate.getChildWithoutCast(0)).getTableName();
String columnWithTable;
if (tableName != null) {
String tblName = tableName.toString();
columnWithTable = tblName + "." + ((SlotRef) predicate.getChild(0)).getColumnName();
columnWithTable = tblName + "." + ((SlotRef) predicate.getChildWithoutCast(0)).getColumnName();
} else {
columnWithTable = ((SlotRef) predicate.getChild(0)).getColumnName();
columnWithTable = ((SlotRef) predicate.getChildWithoutCast(0)).getColumnName();
}
slotNameToMergeExprsMap.computeIfAbsent(columnWithTable, key -> {
slotNameForMerge.add(columnWithTable);