[fix](nereids) fix some nereids bugs (#15714)

1. remove forcing nullable for slot on EmptySetNode.
2. order by xxx desc should use nulls last as default order.
3. don't create runtime filter if runtime filter mode is OFF.
4. group by constant value need check the corresponding expr shouldn't have any aggregation functions.
5. fix two left outer join reorder bug( A left join B left join C).
6. fix semi join and left outer join reorder bug.( A left join B semi join C ).
7. fix group by NULL bug.
8. change ceil and floor function to correct signature.
9. add literal comparasion for string and date type.
10. fix the getOnClauseUsedSlots method may not return valid value.
11. the tightness common type of string and date should be date.
12. the nullability of set operation node's result exprs is not set correctly.
13. Sort node should remove redundent ordering exprs.
This commit is contained in:
starocean999
2023-01-11 17:18:44 +08:00
committed by GitHub
parent d4e4e18b47
commit cfb110c905
33 changed files with 539 additions and 106 deletions

View File

@ -361,9 +361,6 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
List<Slot> output = emptyRelation.getOutput();
TupleDescriptor tupleDescriptor = generateTupleDesc(output, null, context);
for (int i = 0; i < output.size(); i++) {
SlotDescriptor slotDescriptor = tupleDescriptor.getSlots().get(i);
slotDescriptor.setIsNullable(true); // we should set to nullable, or else BE would core
Slot slot = output.get(i);
SlotRef slotRef = context.findSlotRef(slot.getExprId());
slotRef.setLabel(slot.getName());

View File

@ -86,6 +86,7 @@ public class NereidsRewriteJobExecutor extends BatchRulesJob {
.add(topDownBatch(ImmutableList.of(new ExpressionNormalization(cascadesContext.getConnectContext()))))
.add(topDownBatch(ImmutableList.of(new ExpressionOptimization())))
.add(topDownBatch(ImmutableList.of(new ExtractSingleTableExpressionFromDisjunction())))
.add(topDownBatch(ImmutableList.of(new EliminateGroupByConstant())))
.add(topDownBatch(ImmutableList.of(new NormalizeAggregate())))
.add(topDownBatch(RuleSet.PUSH_DOWN_FILTERS, false))
.add(visitorJob(RuleType.INFER_PREDICATES, new InferPredicates()))
@ -107,7 +108,6 @@ public class NereidsRewriteJobExecutor extends BatchRulesJob {
.add(topDownBatch(ImmutableList.of(new CountDistinctRewrite())))
// we need to execute this rule at the end of rewrite
// to avoid two consecutive same project appear when we do optimization.
.add(topDownBatch(ImmutableList.of(new EliminateGroupByConstant())))
.add(topDownBatch(ImmutableList.of(new EliminateOrderByConstant())))
.add(topDownBatch(ImmutableList.of(new EliminateUnnecessaryProject())))
.add(topDownBatch(ImmutableList.of(new SelectMaterializedIndexWithAggregate())))

View File

@ -1108,7 +1108,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
public OrderKey visitSortItem(SortItemContext ctx) {
return ParserUtils.withOrigin(ctx, () -> {
boolean isAsc = ctx.DESC() == null;
boolean isNullFirst = ctx.LAST() == null;
boolean isNullFirst = ctx.FIRST() != null || (ctx.LAST() == null && isAsc);
Expression expression = typedVisit(ctx.expression());
return new OrderKey(expression, isAsc, isNullFirst);
});

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.processor.post;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.thrift.TRuntimeFilterMode;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
@ -57,7 +58,9 @@ public class PlanPostProcessors {
public List<PlanPostProcessor> getProcessors() {
// add processor if we need
Builder<PlanPostProcessor> builder = ImmutableList.builder();
if (cascadesContext.getConnectContext().getSessionVariable().isEnableNereidsRuntimeFilter()) {
if (cascadesContext.getConnectContext().getSessionVariable().isEnableNereidsRuntimeFilter()
&& !cascadesContext.getConnectContext().getSessionVariable().getRuntimeFilterMode()
.toUpperCase().equals(TRuntimeFilterMode.OFF.name())) {
builder.add(new RuntimeFilterGenerator());
if (ConnectContext.get().getSessionVariable().enableRuntimeFilterPrune) {
builder.add(new RuntimeFilterPruner());

View File

@ -104,5 +104,12 @@ public class CheckAnalysis implements AnalysisRuleFactory {
throw new AnalysisException(
"The query contains multi count distinct or sum distinct, each can't have multi columns");
}
Optional<Expression> expr = aggregate.getGroupByExpressions().stream()
.filter(expression -> expression.containsType(AggregateFunction.class)).findFirst();
if (expr.isPresent()) {
throw new AnalysisException(
"GROUP BY expression must not contain aggregate functions: "
+ expr.get().toSql());
}
}
}

View File

@ -21,6 +21,7 @@ import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.rewrite.rules.FoldConstantRule;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
@ -63,7 +64,7 @@ public class ResolveOrdinalInOrderByAndGroupBy implements AnalysisRuleFactory {
})
))
.add(RuleType.RESOLVE_ORDINAL_IN_GROUP_BY.build(
logicalAggregate().then(agg -> {
logicalAggregate().whenNot(agg -> agg.isOrdinalIsResolved()).then(agg -> {
List<NamedExpression> aggOutput = agg.getOutputExpressions();
List<Expression> groupByWithoutOrd = new ArrayList<>();
boolean ordExists = false;
@ -74,6 +75,9 @@ public class ResolveOrdinalInOrderByAndGroupBy implements AnalysisRuleFactory {
int ord = i.getIntValue();
checkOrd(ord, aggOutput.size());
Expression aggExpr = aggOutput.get(ord - 1);
if (aggExpr instanceof Alias) {
aggExpr = ((Alias) aggExpr).child();
}
groupByWithoutOrd.add(aggExpr);
ordExists = true;
} else {
@ -81,11 +85,11 @@ public class ResolveOrdinalInOrderByAndGroupBy implements AnalysisRuleFactory {
}
}
if (ordExists) {
return new LogicalAggregate(groupByWithoutOrd, agg.getOutputExpressions(), agg.child());
return new LogicalAggregate(groupByWithoutOrd, agg.getOutputExpressions(), true,
agg.child());
} else {
return agg;
}
}))).build();
}

View File

@ -31,6 +31,7 @@ import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.nereids.util.PlanUtils;
@ -75,26 +76,25 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory {
GroupPlan a = bottomJoin.left();
GroupPlan b = bottomJoin.right();
GroupPlan c = topJoin.right();
Set<Slot> bOutputSet = b.getOutputSet();
Set<Slot> aOutputSet = a.getOutputSet();
Set<Slot> cOutputSet = c.getOutputSet();
/* ********** Split projects ********** */
Map<Boolean, List<NamedExpression>> projectExprsMap = projects.stream()
.collect(Collectors.partitioningBy(projectExpr -> {
Set<Slot> usedSlots = projectExpr.collect(SlotReference.class::isInstance);
return bOutputSet.containsAll(usedSlots);
return aOutputSet.containsAll(usedSlots);
}));
List<NamedExpression> newLeftProjects = projectExprsMap.get(Boolean.FALSE);
List<NamedExpression> newRightProjects = projectExprsMap.get(Boolean.TRUE);
Set<ExprId> bExprIdSet = InnerJoinLAsscomProject.getExprIdSetForB(bottomJoin.right(),
newRightProjects);
List<NamedExpression> newLeftProjects = projectExprsMap.get(Boolean.TRUE);
List<NamedExpression> newRightProjects = projectExprsMap.get(Boolean.FALSE);
Set<ExprId> aExprIdSet = getExprIdSetForA(bottomJoin.left(),
newLeftProjects);
/* ********** split Conjuncts ********** */
Map<Boolean, List<Expression>> splitHashJoinConjuncts
= InnerJoinLAsscomProject.splitConjunctsWithAlias(
topJoin.getHashJoinConjuncts(), bottomJoin.getHashJoinConjuncts(), bExprIdSet);
List<Expression> newTopHashJoinConjuncts = splitHashJoinConjuncts.get(true);
Map<Boolean, List<Expression>> newHashJoinConjuncts
= createNewConjunctsWithAlias(
topJoin.getHashJoinConjuncts(), bottomJoin.getHashJoinConjuncts(), aExprIdSet);
List<Expression> newTopHashJoinConjuncts = newHashJoinConjuncts.get(true);
Preconditions.checkState(!newTopHashJoinConjuncts.isEmpty(),
"LAsscom newTopHashJoinConjuncts join can't empty");
// When newTopHashJoinConjuncts.size() != bottomJoin.getHashJoinConjuncts().size()
@ -103,25 +103,21 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory {
&& newTopHashJoinConjuncts.size() != bottomJoin.getHashJoinConjuncts().size()) {
return null;
}
List<Expression> newBottomHashJoinConjuncts = splitHashJoinConjuncts.get(false);
List<Expression> newBottomHashJoinConjuncts = newHashJoinConjuncts.get(false);
if (newBottomHashJoinConjuncts.size() == 0) {
return null;
}
Map<Boolean, List<Expression>> splitOtherJoinConjuncts
= InnerJoinLAsscomProject.splitConjunctsWithAlias(
Map<Boolean, List<Expression>> newOtherJoinConjuncts
= createNewConjunctsWithAlias(
topJoin.getOtherJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(),
bExprIdSet);
List<Expression> newTopOtherJoinConjuncts = splitOtherJoinConjuncts.get(true);
// When topJoin type differ from bottomJoin type (LOJ-inner or inner LOJ),
// we just can exchange topJoin and bottomJoin. like:
// Failed in: (A LOJ B on A.id = B.id) join C on c.id = A.id & c.id = B.id (top contain c.id = B.id)
// If type is same like LOJ(LOJ()), we can LAsscom.
if (topJoin.getJoinType() != bottomJoin.getJoinType()
&& newTopOtherJoinConjuncts.size() != bottomJoin.getOtherJoinConjuncts().size()) {
aExprIdSet);
List<Expression> newTopOtherJoinConjuncts = newOtherJoinConjuncts.get(true);
List<Expression> newBottomOtherJoinConjuncts = newOtherJoinConjuncts.get(false);
if (newBottomOtherJoinConjuncts.size() != topJoin.getOtherJoinConjuncts().size()
|| newTopOtherJoinConjuncts.size() != bottomJoin.getOtherJoinConjuncts().size()) {
return null;
}
List<Expression> newBottomOtherJoinConjuncts = splitOtherJoinConjuncts.get(false);
/* ********** replace Conjuncts by projects ********** */
Map<Slot, Slot> inputToOutput = new HashMap<>();
@ -158,10 +154,10 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory {
return usedSlotRefs.stream();
})
.filter(slot -> !cOutputSet.contains(slot))
.collect(Collectors.partitioningBy(slot -> bExprIdSet.contains(slot.getExprId()),
.collect(Collectors.partitioningBy(slot -> aExprIdSet.contains(slot.getExprId()),
Collectors.toSet()));
Set<Slot> aUsedSlots = abOnUsedSlots.get(Boolean.FALSE);
Set<Slot> bUsedSlots = abOnUsedSlots.get(Boolean.TRUE);
Set<Slot> aUsedSlots = abOnUsedSlots.get(Boolean.TRUE);
Set<Slot> bUsedSlots = abOnUsedSlots.get(Boolean.FALSE);
JoinUtils.addSlotsUsedByOn(bUsedSlots, newRightProjects);
JoinUtils.addSlotsUsedByOn(aUsedSlots, newLeftProjects);
@ -200,4 +196,35 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory {
return PlanUtils.project(new ArrayList<>(topJoin.getOutput()), newTopJoin).get();
}).toRule(RuleType.LOGICAL_OUTER_JOIN_LASSCOM_PROJECT);
}
private Map<Boolean, List<Expression>> createNewConjunctsWithAlias(List<Expression> topConjuncts,
List<Expression> bottomConjuncts, Set<ExprId> bExprIdSet) {
// if top join's conjuncts are all related to A, we can do reorder
Map<Boolean, List<Expression>> splitOn = new HashMap<>();
splitOn.put(true, new ArrayList<>());
if (topConjuncts.stream().allMatch(topHashOn -> {
Set<Slot> usedSlots = topHashOn.getInputSlots();
Set<ExprId> usedSlotsId = usedSlots.stream().map(NamedExpression::getExprId)
.collect(Collectors.toSet());
return ExpressionUtils.isIntersecting(bExprIdSet, usedSlotsId);
})) {
// do reorder, create new bottom join conjuncts
splitOn.put(false, new ArrayList<>(topConjuncts));
} else {
// can't reorder, return empty list
splitOn.put(false, new ArrayList<>());
}
List<Expression> newTopHashJoinConjuncts = splitOn.get(true);
newTopHashJoinConjuncts.addAll(bottomConjuncts);
return splitOn;
}
private Set<ExprId> getExprIdSetForA(GroupPlan a, List<NamedExpression> aProject) {
return Stream.concat(
a.getOutput().stream().map(NamedExpression::getExprId),
aProject.stream().map(NamedExpression::getExprId)).collect(Collectors.toSet());
}
}

View File

@ -54,9 +54,10 @@ public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory {
@Override
public Rule build() {
return logicalJoin(logicalJoin(), group())
.when(topJoin -> topJoin.getJoinType() == JoinType.LEFT_SEMI_JOIN
|| topJoin.getJoinType() == JoinType.LEFT_ANTI_JOIN
|| topJoin.getJoinType() == JoinType.NULL_AWARE_LEFT_ANTI_JOIN)
.when(topJoin -> (topJoin.getJoinType().isLeftSemiOrAntiJoin()
&& (topJoin.left().getJoinType().isInnerJoin()
|| topJoin.left().getJoinType().isLeftOuterJoin()
|| topJoin.left().getJoinType().isRightOuterJoin())))
.whenNot(topJoin -> topJoin.left().getJoinType().isSemiOrAntiJoin())
.when(this::conditionChecker)
.whenNot(topJoin -> topJoin.hasJoinHint() || topJoin.left().hasJoinHint())
@ -75,6 +76,8 @@ public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory {
lasscom = ExpressionUtils.isIntersecting(usedSlot, aOutputSet) || lasscom;
}
JoinType newTopJoinType = JoinType.INNER_JOIN;
if (lasscom) {
/*
* topSemiJoin newTopJoin
@ -85,9 +88,10 @@ public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory {
*/
LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
topSemiJoin.getJoinType(),
topSemiJoin.getHashJoinConjuncts(), topSemiJoin.getOtherJoinConjuncts(), JoinHint.NONE,
topSemiJoin.getHashJoinConjuncts(), topSemiJoin.getOtherJoinConjuncts(),
JoinHint.NONE,
a, c);
return new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(),
return new LogicalJoin<>(newTopJoinType, bottomJoin.getHashJoinConjuncts(),
bottomJoin.getOtherJoinConjuncts(), JoinHint.NONE, newBottomSemiJoin, b);
} else {
/*
@ -102,7 +106,7 @@ public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory {
topSemiJoin.getHashJoinConjuncts(), topSemiJoin.getOtherJoinConjuncts(),
JoinHint.NONE,
b, c);
return new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(),
return new LogicalJoin<>(newTopJoinType, bottomJoin.getHashJoinConjuncts(),
bottomJoin.getOtherJoinConjuncts(), JoinHint.NONE, a, newBottomSemiJoin);
}
}).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE);

View File

@ -56,9 +56,10 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
@Override
public Rule build() {
return logicalJoin(logicalProject(logicalJoin()), group())
.when(topJoin -> topJoin.getJoinType() == JoinType.LEFT_SEMI_JOIN
|| topJoin.getJoinType() == JoinType.LEFT_ANTI_JOIN
|| topJoin.getJoinType() == JoinType.NULL_AWARE_LEFT_ANTI_JOIN)
.when(topJoin -> (topJoin.getJoinType().isLeftSemiOrAntiJoin()
&& (topJoin.left().child().getJoinType().isInnerJoin()
|| topJoin.left().child().getJoinType().isLeftOuterJoin()
|| topJoin.left().child().getJoinType().isRightOuterJoin())))
.whenNot(topJoin -> topJoin.left().child().getJoinType().isSemiOrAntiJoin())
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.when(join -> JoinReorderCommon.checkProject(join.left()))
@ -80,6 +81,8 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
lasscom = ExpressionUtils.isIntersecting(usedSlot, aOutputSet) || lasscom;
}
JoinType newTopJoinType = JoinType.INNER_JOIN;
if (lasscom) {
/*-
* topSemiJoin project
@ -94,8 +97,9 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(),
topSemiJoin.getOtherJoinConjuncts(), JoinHint.NONE, a, c);
LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(),
bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(), JoinHint.NONE,
LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(newTopJoinType,
bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(),
JoinHint.NONE,
newBottomSemiJoin, b);
return new LogicalProject<>(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin);
@ -113,8 +117,9 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(),
topSemiJoin.getOtherJoinConjuncts(), JoinHint.NONE, b, c);
LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(),
bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(), JoinHint.NONE,
LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(newTopJoinType,
bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(),
JoinHint.NONE,
a, newBottomSemiJoin);
return new LogicalProject<>(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin);

View File

@ -76,8 +76,8 @@ public class ApplyPullFilterOnProjectUnderAgg extends OneRewriteRuleFactory {
LogicalProject newProject = new LogicalProject<>(newProjects, filter.child());
LogicalFilter newFilter = new LogicalFilter<>(filter.getConjuncts(), newProject);
LogicalAggregate newAgg = new LogicalAggregate<>(
agg.getGroupByExpressions(), agg.getOutputExpressions(), newFilter);
LogicalAggregate newAgg = new LogicalAggregate<>(agg.getGroupByExpressions(),
agg.getOutputExpressions(), agg.isOrdinalIsResolved(), newFilter);
return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(),
apply.getCorrelationFilter(), apply.left(), newAgg);
}).toRule(RuleType.APPLY_PULL_FILTER_ON_PROJECT_UNDER_AGG);

View File

@ -59,7 +59,7 @@ public class EliminateGroupByConstant extends OneRewriteRuleFactory {
lit = expression;
}
}
if (slotGroupByExprs.isEmpty() && lit != null) {
if (slotGroupByExprs.isEmpty() && lit != null && aggregate.getAggregateFunctions().isEmpty()) {
slotGroupByExprs.add(lit);
}
return aggregate.withGroupByAndOutput(ImmutableList.copyOf(slotGroupByExprs), outputExprs);

View File

@ -21,6 +21,9 @@ import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.types.TinyIntType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
@ -133,7 +136,10 @@ public interface NormalizeToSlot {
}
Alias alias = new Alias(expression, expression.toSql());
return new NormalizeToSlotTriplet(expression, alias.toSlot(), alias);
SlotReference slot = (SlotReference) alias.toSlot();
// BE will create a nullable uint8 column to expand NullLiteral when necessary
return new NormalizeToSlotTriplet(expression, expression instanceof NullLiteral ? slot.withDataType(
TinyIntType.INSTANCE) : slot, alias);
}
}
}

View File

@ -24,7 +24,6 @@ import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSi
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DoubleType;
import com.google.common.base.Preconditions;
@ -39,7 +38,7 @@ public class Ceil extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable, ComputePrecisionForRound {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).args(DoubleType.INSTANCE)
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE)
);
/**

View File

@ -24,7 +24,6 @@ import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSi
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DoubleType;
import com.google.common.base.Preconditions;
@ -39,7 +38,7 @@ public class Floor extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable, ComputePrecisionForRound {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).args(DoubleType.INSTANCE)
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE)
);
/**

View File

@ -25,7 +25,6 @@ import org.apache.doris.nereids.trees.expressions.shape.LeafExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.CharType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.StringType;
import org.apache.commons.lang3.StringUtils;
@ -161,16 +160,15 @@ public abstract class Literal extends Expression implements LeafExpression, Comp
return Float.compare((float) getValue(), (float) other.getValue());
} else if (type.isDoubleType()) {
return Double.compare((double) getValue(), (double) other.getValue());
} else if (type.isDecimalV2Type()) {
return Long.compare((Long) getValue(), (Long) other.getValue());
} else if (type.isDateLikeType()) {
// todo process date
return Long.compare((Long) getValue(), (Long) other.getValue());
} else if (type.isDecimalV2Type()) {
return ((BigDecimal) getValue()).compareTo((BigDecimal) other.getValue());
} else if (type instanceof StringType) {
} else if (type.isStringLikeType()) {
return StringUtils.compare((String) getValue(), (String) other.getValue());
} else {
throw new RuntimeException(String.format("Literal {} is not supported!", type.toString()));
}
return -1;
}
@Override

View File

@ -60,6 +60,8 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
// When there are grouping sets/rollup/cube, LogicalAgg is generated by LogicalRepeat.
private final Optional<LogicalRepeat> sourceRepeat;
private final boolean ordinalIsResolved;
/**
* Desc: Constructor for LogicalAggregate.
*/
@ -71,6 +73,12 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
false, Optional.empty(), child);
}
public LogicalAggregate(List<Expression> groupByExpressions,
List<NamedExpression> outputExpressions, boolean ordinalIsResolved, CHILD_TYPE child) {
this(groupByExpressions, outputExpressions, false, ordinalIsResolved, Optional.empty(),
Optional.empty(), Optional.empty(), child);
}
/**
* Desc: Constructor for LogicalAggregate.
* Generated from LogicalRepeat.
@ -89,7 +97,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
boolean normalized,
Optional<LogicalRepeat> sourceRepeat,
CHILD_TYPE child) {
this(groupByExpressions, outputExpressions, normalized, sourceRepeat,
this(groupByExpressions, outputExpressions, normalized, false, sourceRepeat,
Optional.empty(), Optional.empty(), child);
}
@ -100,6 +108,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
List<Expression> groupByExpressions,
List<NamedExpression> outputExpressions,
boolean normalized,
boolean ordinalIsResolved,
Optional<LogicalRepeat> sourceRepeat,
Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties,
@ -108,6 +117,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
this.groupByExpressions = ImmutableList.copyOf(groupByExpressions);
this.outputExpressions = ImmutableList.copyOf(outputExpressions);
this.normalized = normalized;
this.ordinalIsResolved = ordinalIsResolved;
this.sourceRepeat = Objects.requireNonNull(sourceRepeat, "sourceRepeat cannot be null");
}
@ -160,6 +170,10 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
return normalized;
}
public boolean isOrdinalIsResolved() {
return ordinalIsResolved;
}
/**
* Determine the equality with another plan
*/
@ -174,48 +188,51 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
return Objects.equals(groupByExpressions, that.groupByExpressions)
&& Objects.equals(outputExpressions, that.outputExpressions)
&& normalized == that.normalized
&& ordinalIsResolved == that.ordinalIsResolved
&& Objects.equals(sourceRepeat, that.sourceRepeat);
}
@Override
public int hashCode() {
return Objects.hash(groupByExpressions, outputExpressions, normalized, sourceRepeat);
return Objects.hash(groupByExpressions, outputExpressions, normalized, ordinalIsResolved, sourceRepeat);
}
@Override
public LogicalAggregate<Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1);
return new LogicalAggregate<>(groupByExpressions, outputExpressions,
normalized, sourceRepeat, children.get(0));
normalized, ordinalIsResolved, sourceRepeat, Optional.empty(), Optional.empty(), children.get(0));
}
@Override
public LogicalAggregate<Plan> withGroupExpression(Optional<GroupExpression> groupExpression) {
return new LogicalAggregate<>(groupByExpressions, outputExpressions,
normalized, sourceRepeat, groupExpression, Optional.of(getLogicalProperties()), children.get(0));
normalized, ordinalIsResolved, sourceRepeat, groupExpression, Optional.of(getLogicalProperties()),
children.get(0));
}
@Override
public LogicalAggregate<Plan> withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
return new LogicalAggregate<>(groupByExpressions, outputExpressions,
normalized, sourceRepeat,
normalized, ordinalIsResolved, sourceRepeat,
Optional.empty(), logicalProperties, children.get(0));
}
public LogicalAggregate<Plan> withGroupByAndOutput(List<Expression> groupByExprList,
List<NamedExpression> outputExpressionList) {
return new LogicalAggregate<>(groupByExprList, outputExpressionList, normalized, sourceRepeat, child());
return new LogicalAggregate<>(groupByExprList, outputExpressionList, normalized, ordinalIsResolved,
sourceRepeat, Optional.empty(), Optional.empty(), child());
}
@Override
public LogicalAggregate<CHILD_TYPE> withAggOutput(List<NamedExpression> newOutput) {
return new LogicalAggregate<>(groupByExpressions, newOutput, normalized,
return new LogicalAggregate<>(groupByExpressions, newOutput, normalized, ordinalIsResolved,
sourceRepeat, Optional.empty(), Optional.empty(), child());
}
public LogicalAggregate<Plan> withNormalized(List<Expression> normalizedGroupBy,
List<NamedExpression> normalizedOutput, Plan normalizedChild) {
return new LogicalAggregate<>(normalizedGroupBy, normalizedOutput, true,
return new LogicalAggregate<>(normalizedGroupBy, normalizedOutput, true, ordinalIsResolved,
sourceRepeat, Optional.empty(),
Optional.empty(), normalizedChild);
}

View File

@ -42,6 +42,7 @@ import com.google.common.collect.Lists;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@ -151,11 +152,15 @@ public class JoinUtils {
for (Expression expr : join.getHashJoinConjuncts()) {
EqualTo equalTo = (EqualTo) expr;
if (!(equalTo.left() instanceof Slot) || !(equalTo.right() instanceof Slot)) {
// TODO: we could meet a = cast(b as xxx) here, need fix normalize join hash equals future
Optional<Slot> leftSlot = ExpressionUtils.extractSlotOrCastOnSlot(equalTo.left());
Optional<Slot> rightSlot = ExpressionUtils.extractSlotOrCastOnSlot(equalTo.right());
if (!leftSlot.isPresent() || !rightSlot.isPresent()) {
continue;
}
ExprId leftExprId = ((Slot) equalTo.left()).getExprId();
ExprId rightExprId = ((Slot) equalTo.right()).getExprId();
ExprId leftExprId = leftSlot.get().getExprId();
ExprId rightExprId = rightSlot.get().getExprId();
if (checker.isCoveredByLeftSlots(leftExprId)
&& checker.isCoveredByRightSlots(rightExprId)) {

View File

@ -213,6 +213,10 @@ public class TypeCoercionUtils {
}
} else if (left instanceof CharacterType && right instanceof CharacterType) {
tightestCommonType = CharacterType.widerCharacterType((CharacterType) left, (CharacterType) right);
} else if (left instanceof CharacterType && right instanceof DateLikeType
|| left instanceof DateLikeType && right instanceof CharacterType) {
// TODO: need check implicitCastMap to keep the behavior consistent with old optimizer
tightestCommonType = right;
} else if (left instanceof CharacterType || right instanceof CharacterType) {
tightestCommonType = StringType.INSTANCE;
} else if (left instanceof DecimalV2Type && right instanceof IntegralType) {

View File

@ -501,6 +501,9 @@ public abstract class SetOperationNode extends PlanNode {
for (int j = 0; j < exprList.size(); ++j) {
if (resultExprSlots.get(j).isMaterialized()) {
newExprList.add(exprList.get(j));
// TODO: reconsider this, we may change nullable info in previous nereids rules not here.
resultExprSlots.get(j)
.setIsNullable(resultExprSlots.get(j).getIsNullable() || exprList.get(j).isNullable());
}
}
materializedResultExprLists.add(newExprList);

View File

@ -320,7 +320,13 @@ public class SortNode extends PlanNode {
*/
public void finalizeForNereids(TupleDescriptor tupleDescriptor,
List<Expr> outputList, List<Expr> orderingExpr) {
resolvedTupleExprs = Lists.newArrayList(orderingExpr);
resolvedTupleExprs = Lists.newArrayList();
// TODO: should fix the duplicate order by exprs in nereids code later
for (Expr order : orderingExpr) {
if (!resolvedTupleExprs.contains(order)) {
resolvedTupleExprs.add(order);
}
}
for (Expr output : outputList) {
if (!resolvedTupleExprs.contains(output)) {
resolvedTupleExprs.add(output);

View File

@ -75,7 +75,9 @@ class OuterJoinLAsscomProjectTest implements PatternMatchSupported {
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.printlnOrigin()
.applyExploration(OuterJoinLAsscomProject.INSTANCE.build())
.printlnExploration()
.matchesExploration(
logicalJoin(
logicalProject(
@ -103,20 +105,9 @@ class OuterJoinLAsscomProjectTest implements PatternMatchSupported {
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyExploration(OuterJoinLAsscomProject.INSTANCE.build())
.printlnOrigin()
.printlnExploration()
.matchesExploration(
logicalJoin(
logicalProject(
logicalJoin(
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")),
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3"))
).when(join -> join.getHashJoinConjuncts().size() == 1)
).when(project -> project.getProjects().size() == 3), // t1.id Add t3.id, t3.name
logicalProject(
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2"))
).when(project -> project.getProjects().size() == 1)
).when(join -> join.getHashJoinConjuncts().size() == 2)
);
.checkMemo(memo -> {
Assertions.assertEquals(1, memo.getRoot().getLogicalExpressions().size());
});
}
@Test
@ -156,22 +147,10 @@ class OuterJoinLAsscomProjectTest implements PatternMatchSupported {
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.printlnOrigin()
.applyExploration(OuterJoinLAsscomProject.INSTANCE.build())
.matchesExploration(
logicalJoin(
logicalProject(
logicalJoin(
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")),
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3"))
).when(join -> join.getOtherJoinConjuncts().size() == 1
&& join.getHashJoinConjuncts().size() == 1)
),
logicalProject(
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2"))
)
).when(join -> join.getOtherJoinConjuncts().size() == 2
&& join.getHashJoinConjuncts().size() == 2)
)
.printlnExploration();
.checkMemo(memo -> {
Assertions.assertEquals(1, memo.getRoot().getLogicalExpressions().size());
});
}
}