diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/NereidsAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/NereidsAnalyzer.java index 48be0c6978..99c2eb7604 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/NereidsAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/NereidsAnalyzer.java @@ -22,7 +22,6 @@ import org.apache.doris.nereids.jobs.batch.AdjustAggregateNullableForEmptySetJob import org.apache.doris.nereids.jobs.batch.AnalyzeRulesJob; import org.apache.doris.nereids.jobs.batch.AnalyzeSubqueryRulesJob; import org.apache.doris.nereids.jobs.batch.CheckAnalysisJob; -import org.apache.doris.nereids.jobs.batch.TypeCoercionJob; import java.util.Objects; import java.util.Optional; @@ -52,7 +51,6 @@ public class NereidsAnalyzer { new AnalyzeRulesJob(cascadesContext, outerScope).execute(); new AnalyzeSubqueryRulesJob(cascadesContext).execute(); new AdjustAggregateNullableForEmptySetJob(cascadesContext).execute(); - new TypeCoercionJob(cascadesContext).execute(); // check whether analyze result is meaningful new CheckAnalysisJob(cascadesContext).execute(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/AnalyzeRulesJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/AnalyzeRulesJob.java index 0d7443110f..c68e096baf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/AnalyzeRulesJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/AnalyzeRulesJob.java @@ -20,9 +20,8 @@ package org.apache.doris.nereids.jobs.batch; import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.analyzer.Scope; import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount; -import org.apache.doris.nereids.rules.analysis.BindFunction; +import org.apache.doris.nereids.rules.analysis.BindExpression; import org.apache.doris.nereids.rules.analysis.BindRelation; -import org.apache.doris.nereids.rules.analysis.BindSlotReference; import org.apache.doris.nereids.rules.analysis.CheckPolicy; import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots; import org.apache.doris.nereids.rules.analysis.NormalizeRepeat; @@ -32,9 +31,6 @@ import org.apache.doris.nereids.rules.analysis.RegisterCTE; import org.apache.doris.nereids.rules.analysis.ReplaceExpressionByChildOutput; import org.apache.doris.nereids.rules.analysis.ResolveOrdinalInOrderByAndGroupBy; import org.apache.doris.nereids.rules.analysis.UserAuthentication; -import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization; -import org.apache.doris.nereids.rules.expression.rewrite.rules.CharacterLiteralTypeCoercion; -import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion; import org.apache.doris.nereids.rules.rewrite.logical.HideOneRowRelationUnderUnion; import com.google.common.collect.ImmutableList; @@ -61,8 +57,7 @@ public class AnalyzeRulesJob extends BatchRulesJob { new BindRelation(), new CheckPolicy(), new UserAuthentication(), - new BindSlotReference(scope), - new BindFunction(), + new BindExpression(scope), new ProjectToGlobalAggregate(), // this rule check's the logicalProject node's isDisinct property // and replace the logicalProject node with a LogicalAggregate node @@ -73,9 +68,7 @@ public class AnalyzeRulesJob extends BatchRulesJob { new AvgDistinctToSumDivCount(), new ResolveOrdinalInOrderByAndGroupBy(), new ReplaceExpressionByChildOutput(), - new HideOneRowRelationUnderUnion(), - new ExpressionNormalization(cascadesContext.getConnectContext(), - ImmutableList.of(CharacterLiteralTypeCoercion.INSTANCE, TypeCoercion.INSTANCE)) + new HideOneRowRelationUnderUnion() ), topDownBatch( new FillUpMissingSlots(), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java index 1d656fa9b7..e7e5276436 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java @@ -45,7 +45,6 @@ import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; /** @@ -89,31 +88,27 @@ public class RuntimeFilterGenerator extends PlanPostProcessor { List legalTypes = Arrays.stream(TRuntimeFilterType.values()) .filter(type -> (type.getValue() & ctx.getSessionVariable().getRuntimeFilterType()) > 0) .collect(Collectors.toList()); - AtomicInteger cnt = new AtomicInteger(); - join.getHashJoinConjuncts().stream() - .map(EqualTo.class::cast) - // TODO: some complex situation cannot be handled now, see testPushDownThroughJoin. - // TODO: we will support it in later version. - .forEach(expr -> legalTypes.forEach(type -> { - Pair normalizedChildren = checkAndMaybeSwapChild(expr, join); - // aliasTransMap doesn't contain the key, means that the path from the olap scan to the join - // contains join with denied join type. for example: a left join b on a.id = b.id - if (normalizedChildren == null - || !aliasTransferMap.containsKey((Slot) normalizedChildren.first)) { - return; - } - Pair slots = Pair.of( - aliasTransferMap.get((Slot) normalizedChildren.first).second.toSlot(), - ((Slot) normalizedChildren.second)); - RuntimeFilter filter = new RuntimeFilter(generator.getNextId(), - slots.second, slots.first, type, - cnt.getAndIncrement(), join); - ctx.addJoinToTargetMap(join, slots.first.getExprId()); - ctx.setTargetExprIdToFilter(slots.first.getExprId(), filter); - ctx.setTargetsOnScanNode( - aliasTransferMap.get((Slot) normalizedChildren.first).first, - slots.first); - })); + // TODO: some complex situation cannot be handled now, see testPushDownThroughJoin. + // TODO: we will support it in later version. + for (int i = 0; i < join.getHashJoinConjuncts().size(); i++) { + EqualTo expr = (EqualTo) join.getHashJoinConjuncts().get(i); + for (TRuntimeFilterType type : legalTypes) { + Pair normalizedChildren = checkAndMaybeSwapChild(expr, join); + // aliasTransMap doesn't contain the key, means that the path from the olap scan to the join + // contains join with denied join type. for example: a left join b on a.id = b.id + if (normalizedChildren == null || !aliasTransferMap.containsKey((Slot) normalizedChildren.first)) { + continue; + } + Pair slots = Pair.of( + aliasTransferMap.get((Slot) normalizedChildren.first).second.toSlot(), + ((Slot) normalizedChildren.second)); + RuntimeFilter filter = new RuntimeFilter(generator.getNextId(), + slots.second, slots.first, type, i, join); + ctx.addJoinToTargetMap(join, slots.first.getExprId()); + ctx.setTargetExprIdToFilter(slots.first.getExprId(), filter); + ctx.setTargetsOnScanNode(aliasTransferMap.get((Slot) normalizedChildren.first).first, slots.first); + } + } } return join; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java similarity index 72% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java index 18f9203cc3..6d56a87579 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java @@ -17,17 +17,20 @@ package org.apache.doris.nereids.rules.analysis; -import org.apache.doris.catalog.BuiltinAggregateFunctions; +import org.apache.doris.catalog.Env; +import org.apache.doris.catalog.FunctionRegistry; import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.analyzer.Scope; import org.apache.doris.nereids.analyzer.UnboundFunction; import org.apache.doris.nereids.analyzer.UnboundOneRowRelation; import org.apache.doris.nereids.analyzer.UnboundSlot; +import org.apache.doris.nereids.analyzer.UnboundTVFRelation; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.properties.OrderKey; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.rules.analysis.BindFunction.FunctionBinder; +import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion; import org.apache.doris.nereids.trees.UnaryNode; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.BoundStar; @@ -36,7 +39,13 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.TVFProperties; +import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; import org.apache.doris.nereids.trees.expressions.functions.Function; +import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction; +import org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinType; @@ -57,6 +66,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation; import org.apache.doris.nereids.trees.plans.logical.LogicalSort; +import org.apache.doris.nereids.trees.plans.logical.LogicalTVFRelation; import org.apache.doris.nereids.trees.plans.logical.UsingJoin; import com.google.common.base.Preconditions; @@ -81,11 +91,11 @@ import java.util.stream.Stream; * BindSlotReference. */ @SuppressWarnings("OptionalUsedAsFieldOrParameterType") -public class BindSlotReference implements AnalysisRuleFactory { +public class BindExpression implements AnalysisRuleFactory { private final Optional outerScope; - public BindSlotReference(Optional outerScope) { + public BindExpression(Optional outerScope) { this.outerScope = Objects.requireNonNull(outerScope, "outerScope cannot be null"); } @@ -104,18 +114,23 @@ public class BindSlotReference implements AnalysisRuleFactory { logicalProject().when(Plan::canBind).thenApply(ctx -> { LogicalProject project = ctx.root; List boundProjections = - bind(project.getProjects(), project.children(), ctx.cascadesContext); - List boundExceptions = bind(project.getExcepts(), project.children(), + bindSlot(project.getProjects(), project.children(), ctx.cascadesContext); + List boundExceptions = bindSlot(project.getExcepts(), project.children(), ctx.cascadesContext); boundProjections = flatBoundStar(boundProjections, boundExceptions); + boundProjections = boundProjections.stream() + .map(expr -> bindFunction(expr, ctx.cascadesContext)) + .collect(ImmutableList.toImmutableList()); return new LogicalProject<>(boundProjections, project.child(), project.isDistinct()); }) ), RuleType.BINDING_FILTER_SLOT.build( logicalFilter().when(Plan::canBind).thenApply(ctx -> { LogicalFilter filter = ctx.root; - Set boundConjuncts - = bind(filter.getConjuncts(), filter.children(), ctx.cascadesContext); + Set boundConjuncts = filter.getConjuncts().stream() + .map(expr -> bindSlot(expr, filter.children(), ctx.cascadesContext)) + .map(expr -> bindFunction(expr, ctx.cascadesContext)) + .collect(Collectors.toSet()); return new LogicalFilter<>(boundConjuncts, filter.child()); }) ), @@ -141,7 +156,7 @@ public class BindSlotReference implements AnalysisRuleFactory { .peek(s -> slotNames.add(s.getName())) .collect(Collectors.toList())); for (Expression unboundSlot : unboundSlots) { - Expression expression = new Binder(scope, ctx.cascadesContext).bind(unboundSlot); + Expression expression = new SlotBinder(scope, ctx.cascadesContext).bind(unboundSlot); leftSlots.add(expression); } slotNames.clear(); @@ -151,7 +166,7 @@ public class BindSlotReference implements AnalysisRuleFactory { .collect(Collectors.toList())); List rightSlots = new ArrayList<>(); for (Expression unboundSlot : unboundSlots) { - Expression expression = new Binder(scope, ctx.cascadesContext).bind(unboundSlot); + Expression expression = new SlotBinder(scope, ctx.cascadesContext).bind(unboundSlot); rightSlots.add(expression); } int size = leftSlots.size(); @@ -166,10 +181,12 @@ public class BindSlotReference implements AnalysisRuleFactory { logicalJoin().when(Plan::canBind).thenApply(ctx -> { LogicalJoin join = ctx.root; List cond = join.getOtherJoinConjuncts().stream() - .map(expr -> bind(expr, join.children(), ctx.cascadesContext)) + .map(expr -> bindSlot(expr, join.children(), ctx.cascadesContext)) + .map(expr -> bindFunction(expr, ctx.cascadesContext)) .collect(Collectors.toList()); List hashJoinConjuncts = join.getHashJoinConjuncts().stream() - .map(expr -> bind(expr, join.children(), ctx.cascadesContext)) + .map(expr -> bindSlot(expr, join.children(), ctx.cascadesContext)) + .map(expr -> bindFunction(expr, ctx.cascadesContext)) .collect(Collectors.toList()); return new LogicalJoin<>(join.getJoinType(), hashJoinConjuncts, cond, join.getHint(), join.left(), join.right()); @@ -178,8 +195,10 @@ public class BindSlotReference implements AnalysisRuleFactory { RuleType.BINDING_AGGREGATE_SLOT.build( logicalAggregate().when(Plan::canBind).thenApply(ctx -> { LogicalAggregate agg = ctx.root; - List output = - bind(agg.getOutputExpressions(), agg.children(), ctx.cascadesContext); + List output = agg.getOutputExpressions().stream() + .map(expr -> bindSlot(expr, agg.children(), ctx.cascadesContext)) + .map(expr -> bindFunction(expr, ctx.cascadesContext)) + .collect(ImmutableList.toImmutableList()); // The columns referenced in group by are first obtained from the child's output, // and then from the node's output @@ -235,17 +254,8 @@ public class BindSlotReference implements AnalysisRuleFactory { .filter(ne -> ne instanceof Alias) .map(Alias.class::cast) // agg function cannot be bound with group_by_key - // TODO(morrySnow): after bind function moved here, - // we could just use instanceof AggregateFunction - .filter(alias -> !alias.child().anyMatch(expr -> { - if (expr instanceof UnboundFunction) { - UnboundFunction unboundFunction = (UnboundFunction) expr; - return BuiltinAggregateFunctions.INSTANCE.aggFuncNames.contains( - unboundFunction.getName().toLowerCase()); - } - return false; - } - )) + .filter(alias -> !alias.child() + .anyMatch(expr -> expr instanceof AggregateFunction)) .forEach(alias -> childOutputsToExpr.putIfAbsent(alias.getName(), alias.child())); List replacedGroupBy = agg.getGroupByExpressions().stream() @@ -280,9 +290,10 @@ public class BindSlotReference implements AnalysisRuleFactory { .collect(Collectors.toSet()); boundSlots.addAll(outputSlots); - Binder binder = new Binder(toScope(Lists.newArrayList(boundSlots)), ctx.cascadesContext); - Binder childBinder - = new Binder(toScope(new ArrayList<>(agg.child().getOutputSet())), ctx.cascadesContext); + SlotBinder binder = new SlotBinder( + toScope(Lists.newArrayList(boundSlots)), ctx.cascadesContext); + SlotBinder childBinder = new SlotBinder( + toScope(new ArrayList<>(agg.child().getOutputSet())), ctx.cascadesContext); List groupBy = replacedGroupBy.stream() .map(expression -> { @@ -305,15 +316,19 @@ public class BindSlotReference implements AnalysisRuleFactory { if (hasUnBound.test(groupBy)) { throw new AnalysisException("cannot bind GROUP BY KEY: " + unboundGroupBys.get(0).toSql()); } + groupBy = groupBy.stream() + .map(expr -> bindFunction(expr, ctx.cascadesContext)) + .collect(ImmutableList.toImmutableList()); return agg.withGroupByAndOutput(groupBy, output); }) ), RuleType.BINDING_REPEAT_SLOT.build( logicalRepeat().when(Plan::canBind).thenApply(ctx -> { LogicalRepeat repeat = ctx.root; - - List output = - bind(repeat.getOutputExpressions(), repeat.children(), ctx.cascadesContext); + List output = repeat.getOutputExpressions().stream() + .map(expr -> bindSlot(expr, repeat.children(), ctx.cascadesContext)) + .map(expr -> bindFunction(expr, ctx.cascadesContext)) + .collect(ImmutableList.toImmutableList()); // The columns referenced in group by are first obtained from the child's output, // and then from the node's output @@ -343,7 +358,10 @@ public class BindSlotReference implements AnalysisRuleFactory { List> groupingSets = replacedGroupingSets .stream() - .map(groupingSet -> bind(groupingSet, repeat.children(), ctx.cascadesContext)) + .map(groupingSet -> groupingSet.stream() + .map(expr -> bindSlot(expr, repeat.children(), ctx.cascadesContext)) + .map(expr -> bindFunction(expr, ctx.cascadesContext)) + .collect(ImmutableList.toImmutableList())) .collect(ImmutableList.toImmutableList()); List newOutput = adjustNullableForRepeat(groupingSets, output); return repeat.withGroupSetsAndOutput(groupingSets, newOutput); @@ -383,7 +401,8 @@ public class BindSlotReference implements AnalysisRuleFactory { List sortItemList = sort.getOrderKeys() .stream() .map(orderKey -> { - Expression item = bind(orderKey.getExpr(), sort.child(), ctx.cascadesContext); + Expression item = bindSlot(orderKey.getExpr(), sort.child(), ctx.cascadesContext); + item = bindFunction(item, ctx.cascadesContext); return new OrderKey(item, orderKey.isAsc(), orderKey.isNullFirst()); }).collect(Collectors.toList()); return new LogicalSort<>(sortItemList, sort.child()); @@ -393,11 +412,12 @@ public class BindSlotReference implements AnalysisRuleFactory { logicalHaving(aggregate()).when(Plan::canBind).thenApply(ctx -> { LogicalHaving> having = ctx.root; Plan childPlan = having.child(); - Set boundConjuncts = having.getConjuncts().stream().map( - expr -> { - expr = bind(expr, childPlan.children(), ctx.cascadesContext); - return bind(expr, childPlan, ctx.cascadesContext); + Set boundConjuncts = having.getConjuncts().stream() + .map(expr -> { + expr = bindSlot(expr, childPlan.children(), ctx.cascadesContext); + return bindSlot(expr, childPlan, ctx.cascadesContext); }) + .map(expr -> bindFunction(expr, ctx.cascadesContext)) .collect(Collectors.toSet()); return new LogicalHaving<>(boundConjuncts, having.child()); }) @@ -406,11 +426,12 @@ public class BindSlotReference implements AnalysisRuleFactory { logicalHaving(any()).when(Plan::canBind).thenApply(ctx -> { LogicalHaving having = ctx.root; Plan childPlan = having.child(); - Set boundConjuncts = having.getConjuncts().stream().map( - expr -> { - expr = bind(expr, childPlan, ctx.cascadesContext); - return bind(expr, childPlan.children(), ctx.cascadesContext); - }) + Set boundConjuncts = having.getConjuncts().stream() + .map(expr -> { + expr = bindSlot(expr, childPlan, ctx.cascadesContext); + return bindSlot(expr, childPlan.children(), ctx.cascadesContext); + }) + .map(expr -> bindFunction(expr, ctx.cascadesContext)) .collect(Collectors.toSet()); return new LogicalHaving<>(boundConjuncts, having.child()); }) @@ -421,7 +442,8 @@ public class BindSlotReference implements AnalysisRuleFactory { UnboundOneRowRelation oneRowRelation = ctx.root; List projects = oneRowRelation.getProjects() .stream() - .map(project -> bind(project, ImmutableList.of(), ctx.cascadesContext)) + .map(project -> bindSlot(project, ImmutableList.of(), ctx.cascadesContext)) + .map(project -> bindFunction(project, ctx.cascadesContext)) .collect(Collectors.toList()); return new LogicalOneRowRelation(projects); }) @@ -450,26 +472,31 @@ public class BindSlotReference implements AnalysisRuleFactory { }) ), RuleType.BINDING_GENERATE_SLOT.build( - logicalGenerate().when(Plan::canBind).thenApply(ctx -> { - LogicalGenerate generate = ctx.root; - List boundSlotGenerators - = bind(generate.getGenerators(), generate.children(), ctx.cascadesContext); - List boundFunctionGenerators = boundSlotGenerators.stream() - .map(f -> FunctionBinder.INSTANCE.bindTableGeneratingFunction( - (UnboundFunction) f, ctx.statementContext)) - .collect(Collectors.toList()); - ImmutableList.Builder slotBuilder = ImmutableList.builder(); - for (int i = 0; i < generate.getGeneratorOutput().size(); i++) { - Function generator = boundFunctionGenerators.get(i); - UnboundSlot slot = (UnboundSlot) generate.getGeneratorOutput().get(i); - Preconditions.checkState(slot.getNameParts().size() == 2, - "the size of nameParts of UnboundSlot in LogicalGenerate must be 2."); - Slot boundSlot = new SlotReference(slot.getNameParts().get(1), generator.getDataType(), - generator.nullable(), ImmutableList.of(slot.getNameParts().get(0))); - slotBuilder.add(boundSlot); - } - return new LogicalGenerate<>(boundFunctionGenerators, slotBuilder.build(), generate.child()); - }) + logicalGenerate().when(Plan::canBind).thenApply(ctx -> { + LogicalGenerate generate = ctx.root; + List boundSlotGenerators + = bindSlot(generate.getGenerators(), generate.children(), ctx.cascadesContext); + List boundFunctionGenerators = boundSlotGenerators.stream() + .map(f -> bindTableGeneratingFunction((UnboundFunction) f, ctx.cascadesContext)) + .collect(Collectors.toList()); + ImmutableList.Builder slotBuilder = ImmutableList.builder(); + for (int i = 0; i < generate.getGeneratorOutput().size(); i++) { + Function generator = boundFunctionGenerators.get(i); + UnboundSlot slot = (UnboundSlot) generate.getGeneratorOutput().get(i); + Preconditions.checkState(slot.getNameParts().size() == 2, + "the size of nameParts of UnboundSlot in LogicalGenerate must be 2."); + Slot boundSlot = new SlotReference(slot.getNameParts().get(1), generator.getDataType(), + generator.nullable(), ImmutableList.of(slot.getNameParts().get(0))); + slotBuilder.add(boundSlot); + } + return new LogicalGenerate<>(boundFunctionGenerators, slotBuilder.build(), generate.child()); + }) + ), + RuleType.BINDING_UNBOUND_TVF_RELATION_FUNCTION.build( + unboundTVFRelation().thenApply(ctx -> { + UnboundTVFRelation relation = ctx.root; + return bindTableValuedFunction(relation, ctx.statementContext); + }) ), // when child update, we need update current plan's logical properties, @@ -502,8 +529,9 @@ public class BindSlotReference implements AnalysisRuleFactory { List sortItemList = sort.getOrderKeys() .stream() .map(orderKey -> { - Expression item = bind(orderKey.getExpr(), plan, ctx); - item = bind(item, plan.children(), ctx); + Expression item = bindSlot(orderKey.getExpr(), plan, ctx); + item = bindSlot(item, plan.children(), ctx); + item = bindFunction(item, ctx); return new OrderKey(item, orderKey.isAsc(), orderKey.isNullFirst()); }).collect(Collectors.toList()); return new LogicalSort<>(sortItemList, sort.child()); @@ -525,29 +553,36 @@ public class BindSlotReference implements AnalysisRuleFactory { .collect(ImmutableList.toImmutableList()); } - private List bind(List exprList, List inputs, CascadesContext cascadesContext) { + private List bindSlot( + List exprList, List inputs, CascadesContext cascadesContext) { return exprList.stream() - .map(expr -> bind(expr, inputs, cascadesContext)) + .map(expr -> bindSlot(expr, inputs, cascadesContext)) .collect(Collectors.toList()); } - private Set bind(Set exprList, List inputs, CascadesContext cascadesContext) { + private Set bindSlot( + Set exprList, List inputs, CascadesContext cascadesContext) { return exprList.stream() - .map(expr -> bind(expr, inputs, cascadesContext)) + .map(expr -> bindSlot(expr, inputs, cascadesContext)) .collect(Collectors.toSet()); } @SuppressWarnings("unchecked") - private E bind(E expr, Plan input, CascadesContext cascadesContext) { - return (E) new Binder(toScope(input.getOutput()), cascadesContext).bind(expr); + private E bindSlot(E expr, Plan input, CascadesContext cascadesContext) { + return (E) new SlotBinder(toScope(input.getOutput()), cascadesContext).bind(expr); } @SuppressWarnings("unchecked") - private E bind(E expr, List inputs, CascadesContext cascadesContext) { + private E bindSlot(E expr, List inputs, CascadesContext cascadesContext) { List boundedSlots = inputs.stream() .flatMap(input -> input.getOutput().stream()) .collect(Collectors.toList()); - return (E) new Binder(toScope(boundedSlots), cascadesContext).bind(expr); + return (E) new SlotBinder(toScope(boundedSlots), cascadesContext).bind(expr); + } + + @SuppressWarnings("unchecked") + private E bindFunction(E expr, CascadesContext cascadesContext) { + return (E) FunctionBinder.INSTANCE.bind(expr, cascadesContext); } /** @@ -578,4 +613,36 @@ public class BindSlotReference implements AnalysisRuleFactory { return slotReference; } } + + private LogicalTVFRelation bindTableValuedFunction(UnboundTVFRelation unboundTVFRelation, + StatementContext statementContext) { + Env env = statementContext.getConnectContext().getEnv(); + FunctionRegistry functionRegistry = env.getFunctionRegistry(); + + String functionName = unboundTVFRelation.getFunctionName(); + TVFProperties arguments = unboundTVFRelation.getProperties(); + FunctionBuilder functionBuilder = functionRegistry.findFunctionBuilder(functionName, arguments); + BoundFunction function = functionBuilder.build(functionName, arguments); + if (!(function instanceof TableValuedFunction)) { + throw new AnalysisException(function.toSql() + " is not a TableValuedFunction"); + } + return new LogicalTVFRelation(unboundTVFRelation.getId(), (TableValuedFunction) function); + } + + private BoundFunction bindTableGeneratingFunction(UnboundFunction unboundFunction, + CascadesContext cascadesContext) { + List boundArguments = unboundFunction.getArguments().stream() + .map(e -> bindFunction(e, cascadesContext)) + .collect(Collectors.toList()); + FunctionRegistry functionRegistry = cascadesContext.getConnectContext().getEnv().getFunctionRegistry(); + + String functionName = unboundFunction.getName(); + FunctionBuilder functionBuilder = functionRegistry.findFunctionBuilder(functionName, boundArguments); + BoundFunction function = functionBuilder.build(functionName, boundArguments); + if (!(function instanceof TableGeneratingFunction)) { + throw new AnalysisException(function.toSql() + " is not a TableGeneratingFunction"); + } + function = (BoundFunction) TypeCoercion.INSTANCE.rewrite(function, null); + return function; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java deleted file mode 100644 index 0bb4e17296..0000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java +++ /dev/null @@ -1,269 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.analysis; - -import org.apache.doris.analysis.ArithmeticExpr.Operator; -import org.apache.doris.catalog.Env; -import org.apache.doris.catalog.FunctionRegistry; -import org.apache.doris.nereids.StatementContext; -import org.apache.doris.nereids.analyzer.UnboundFunction; -import org.apache.doris.nereids.analyzer.UnboundTVFRelation; -import org.apache.doris.nereids.exceptions.AnalysisException; -import org.apache.doris.nereids.jobs.batch.CheckLegalityBeforeTypeCoercion; -import org.apache.doris.nereids.properties.OrderKey; -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext; -import org.apache.doris.nereids.rules.expression.rewrite.rules.CharacterLiteralTypeCoercion; -import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.TVFProperties; -import org.apache.doris.nereids.trees.expressions.TimestampArithmetic; -import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; -import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder; -import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction; -import org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction; -import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; -import org.apache.doris.nereids.trees.plans.GroupPlan; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; -import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; -import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation; -import org.apache.doris.nereids.trees.plans.logical.LogicalProject; -import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; -import org.apache.doris.nereids.trees.plans.logical.LogicalSort; -import org.apache.doris.nereids.trees.plans.logical.LogicalTVFRelation; -import org.apache.doris.qe.ConnectContext; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; - -import java.util.List; -import java.util.Locale; -import java.util.Set; - -/** - * BindFunction. - */ -public class BindFunction implements AnalysisRuleFactory { - @Override - public List buildRules() { - return ImmutableList.of( - RuleType.BINDING_ONE_ROW_RELATION_FUNCTION.build( - logicalOneRowRelation().thenApply(ctx -> { - LogicalOneRowRelation oneRowRelation = ctx.root; - List projects = oneRowRelation.getProjects(); - List boundProjects = bindAndTypeCoercion(projects, ctx.connectContext); - if (projects.equals(boundProjects)) { - return oneRowRelation; - } - return new LogicalOneRowRelation(boundProjects); - }) - ), - RuleType.BINDING_PROJECT_FUNCTION.build( - logicalProject().thenApply(ctx -> { - LogicalProject project = ctx.root; - List boundExpr = bindAndTypeCoercion(project.getProjects(), - ctx.connectContext); - return new LogicalProject<>(boundExpr, project.child(), project.isDistinct()); - }) - ), - RuleType.BINDING_AGGREGATE_FUNCTION.build( - logicalAggregate().thenApply(ctx -> { - LogicalAggregate agg = ctx.root; - List groupBy = bindAndTypeCoercion(agg.getGroupByExpressions(), - ctx.connectContext); - List output = bindAndTypeCoercion(agg.getOutputExpressions(), - ctx.connectContext); - return agg.withGroupByAndOutput(groupBy, output); - }) - ), - RuleType.BINDING_REPEAT_FUNCTION.build( - logicalRepeat().thenApply(ctx -> { - LogicalRepeat repeat = ctx.root; - List> groupingSets = repeat.getGroupingSets() - .stream() - .map(groupingSet -> bindAndTypeCoercion(groupingSet, ctx.connectContext)) - .collect(ImmutableList.toImmutableList()); - List output = bindAndTypeCoercion(repeat.getOutputExpressions(), - ctx.connectContext); - return repeat.withGroupSetsAndOutput(groupingSets, output); - }) - ), - RuleType.BINDING_FILTER_FUNCTION.build( - logicalFilter().thenApply(ctx -> { - LogicalFilter filter = ctx.root; - Set conjuncts = bindAndTypeCoercion(filter.getConjuncts(), ctx.connectContext); - return new LogicalFilter<>(conjuncts, filter.child()); - }) - ), - RuleType.BINDING_HAVING_FUNCTION.build( - logicalHaving().thenApply(ctx -> { - LogicalHaving having = ctx.root; - Set conjuncts = bindAndTypeCoercion(having.getConjuncts(), ctx.connectContext); - return new LogicalHaving<>(conjuncts, having.child()); - }) - ), - RuleType.BINDING_SORT_FUNCTION.build( - logicalSort().thenApply(ctx -> { - LogicalSort sort = ctx.root; - List orderKeys = sort.getOrderKeys().stream() - .map(orderKey -> new OrderKey( - bindAndTypeCoercion(orderKey.getExpr(), - ctx.connectContext.getEnv(), - new ExpressionRewriteContext(ctx.connectContext) - ), - orderKey.isAsc(), - orderKey.isNullFirst()) - - ) - .collect(ImmutableList.toImmutableList()); - return new LogicalSort<>(orderKeys, sort.child()); - }) - ), - RuleType.BINDING_JOIN_FUNCTION.build( - logicalJoin().thenApply(ctx -> { - LogicalJoin join = ctx.root; - List hashConjuncts = bindAndTypeCoercion(join.getHashJoinConjuncts(), - ctx.connectContext); - List otherConjuncts = bindAndTypeCoercion(join.getOtherJoinConjuncts(), - ctx.connectContext); - return new LogicalJoin<>(join.getJoinType(), hashConjuncts, otherConjuncts, - join.getHint(), - join.left(), join.right()); - }) - ), - RuleType.BINDING_UNBOUND_TVF_RELATION_FUNCTION.build( - unboundTVFRelation().thenApply(ctx -> { - UnboundTVFRelation relation = ctx.root; - return FunctionBinder.INSTANCE.bindTableValuedFunction(relation, ctx.statementContext); - }) - ) - ); - } - - private List bindAndTypeCoercion(List exprList, ConnectContext ctx) { - ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx); - return exprList.stream() - .map(expr -> bindAndTypeCoercion(expr, ctx.getEnv(), rewriteContext)) - .collect(ImmutableList.toImmutableList()); - } - - private E bindAndTypeCoercion(E expr, Env env, ExpressionRewriteContext ctx) { - expr = FunctionBinder.INSTANCE.bind(expr, env); - expr = (E) expr.accept(CheckLegalityBeforeTypeCoercion.INSTANCE, ctx); - expr = (E) CharacterLiteralTypeCoercion.INSTANCE.rewrite(expr, ctx); - return (E) TypeCoercion.INSTANCE.rewrite(expr, null); - } - - private Set bindAndTypeCoercion(Set exprSet, ConnectContext ctx) { - ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx); - return exprSet.stream() - .map(expr -> bindAndTypeCoercion(expr, ctx.getEnv(), rewriteContext)) - .collect(ImmutableSet.toImmutableSet()); - } - - /** - * function binder - */ - public static class FunctionBinder extends DefaultExpressionRewriter { - public static final FunctionBinder INSTANCE = new FunctionBinder(); - - public E bind(E expression, Env env) { - return (E) expression.accept(this, env); - } - - /** - * bindTableValuedFunction - */ - public LogicalTVFRelation bindTableValuedFunction(UnboundTVFRelation unboundTVFRelation, - StatementContext statementContext) { - Env env = statementContext.getConnectContext().getEnv(); - FunctionRegistry functionRegistry = env.getFunctionRegistry(); - - String functionName = unboundTVFRelation.getFunctionName(); - TVFProperties arguments = unboundTVFRelation.getProperties(); - FunctionBuilder functionBuilder = functionRegistry.findFunctionBuilder(functionName, arguments); - BoundFunction function = functionBuilder.build(functionName, arguments); - if (!(function instanceof TableValuedFunction)) { - throw new AnalysisException(function.toSql() + " is not a TableValuedFunction"); - } - return new LogicalTVFRelation(unboundTVFRelation.getId(), (TableValuedFunction) function); - } - - /** - * bindTableGeneratingFunction - */ - public BoundFunction bindTableGeneratingFunction(UnboundFunction unboundFunction, - StatementContext statementContext) { - Env env = statementContext.getConnectContext().getEnv(); - List boundArguments = unboundFunction.getArguments().stream() - .map(e -> INSTANCE.bind(e, env)) - .collect(ImmutableList.toImmutableList()); - FunctionRegistry functionRegistry = env.getFunctionRegistry(); - - String functionName = unboundFunction.getName(); - FunctionBuilder functionBuilder = functionRegistry.findFunctionBuilder(functionName, boundArguments); - BoundFunction function = functionBuilder.build(functionName, boundArguments); - if (!(function instanceof TableGeneratingFunction)) { - throw new AnalysisException(function.toSql() + " is not a TableGeneratingFunction"); - } - - return function; - } - - @Override - public Expression visitUnboundFunction(UnboundFunction unboundFunction, Env env) { - unboundFunction = (UnboundFunction) super.visitUnboundFunction(unboundFunction, env); - - FunctionRegistry functionRegistry = env.getFunctionRegistry(); - String functionName = unboundFunction.getName(); - List arguments = unboundFunction.isDistinct() - ? ImmutableList.builder() - .add(unboundFunction.isDistinct()) - .addAll(unboundFunction.getArguments()) - .build() - : (List) unboundFunction.getArguments(); - - FunctionBuilder builder = functionRegistry.findFunctionBuilder(functionName, arguments); - BoundFunction boundFunction = builder.build(functionName, arguments); - return boundFunction; - } - - /** - * gets the method for calculating the time. - * e.g. YEARS_ADD、YEARS_SUB、DAYS_ADD 、DAYS_SUB - */ - @Override - public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, Env context) { - arithmetic = (TimestampArithmetic) super.visitTimestampArithmetic(arithmetic, context); - - String funcOpName; - if (arithmetic.getFuncName() == null) { - // e.g. YEARS_ADD, MONTHS_SUB - funcOpName = String.format("%sS_%s", arithmetic.getTimeUnit(), - (arithmetic.getOp() == Operator.ADD) ? "ADD" : "SUB"); - } else { - funcOpName = arithmetic.getFuncName(); - } - return arithmetic.withFuncName(funcOpName.toLowerCase(Locale.ROOT)); - } - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FunctionBinder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FunctionBinder.java new file mode 100644 index 0000000000..4cc49fb692 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FunctionBinder.java @@ -0,0 +1,415 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.analysis; + +import org.apache.doris.analysis.ArithmeticExpr.Operator; +import org.apache.doris.catalog.FunctionRegistry; +import org.apache.doris.catalog.PrimitiveType; +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.analyzer.UnboundFunction; +import org.apache.doris.nereids.trees.expressions.BinaryOperator; +import org.apache.doris.nereids.trees.expressions.BitNot; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.Divide; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.InPredicate; +import org.apache.doris.nereids.trees.expressions.IntegralDivide; +import org.apache.doris.nereids.trees.expressions.TimestampArithmetic; +import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; +import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.expressions.literal.CharLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DateLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; +import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.expressions.literal.LargeIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral; +import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; +import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral; +import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.BooleanType; +import org.apache.doris.nereids.types.CharType; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DateTimeType; +import org.apache.doris.nereids.types.DateType; +import org.apache.doris.nereids.types.DecimalV2Type; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.FloatType; +import org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.types.VarcharType; +import org.apache.doris.nereids.types.coercion.AbstractDataType; +import org.apache.doris.nereids.types.coercion.CharacterType; +import org.apache.doris.nereids.types.coercion.FractionalType; +import org.apache.doris.nereids.types.coercion.IntegralType; +import org.apache.doris.nereids.types.coercion.NumericType; +import org.apache.doris.nereids.util.TypeCoercionUtils; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.List; +import java.util.Locale; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * function binder + */ +class FunctionBinder extends DefaultExpressionRewriter { + public static final FunctionBinder INSTANCE = new FunctionBinder(); + + public E bind(E expression, CascadesContext context) { + return (E) expression.accept(this, context); + } + + @Override + public Expression visit(Expression expr, CascadesContext context) { + expr = super.visit(expr, context); + expr.checkLegalityBeforeTypeCoercion(); + if (expr instanceof ImplicitCastInputTypes) { + List expectedInputTypes = ((ImplicitCastInputTypes) expr).expectedInputTypes(); + if (!expectedInputTypes.isEmpty()) { + return visitImplicitCastInputTypes(expr, expectedInputTypes); + } + } + return expr; + } + + /* ******************************************************************************************** + * bind function + * ******************************************************************************************** */ + + @Override + public Expression visitUnboundFunction(UnboundFunction unboundFunction, CascadesContext context) { + unboundFunction = unboundFunction.withChildren(unboundFunction.children().stream() + .map(e -> e.accept(this, context)).collect(Collectors.toList())); + + // bind function + FunctionRegistry functionRegistry = context.getConnectContext().getEnv().getFunctionRegistry(); + String functionName = unboundFunction.getName(); + List arguments = unboundFunction.isDistinct() + ? ImmutableList.builder() + .add(unboundFunction.isDistinct()) + .addAll(unboundFunction.getArguments()) + .build() + : (List) unboundFunction.getArguments(); + + FunctionBuilder builder = functionRegistry.findFunctionBuilder(functionName, arguments); + BoundFunction boundFunction = builder.build(functionName, arguments); + + // check + boundFunction.checkLegalityBeforeTypeCoercion(); + + // type coercion + return visitImplicitCastInputTypes(boundFunction, boundFunction.expectedInputTypes()); + } + + /** + * gets the method for calculating the time. + * e.g. YEARS_ADD、YEARS_SUB、DAYS_ADD 、DAYS_SUB + */ + @Override + public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, CascadesContext context) { + Expression left = arithmetic.left().accept(this, context); + Expression right = arithmetic.right().accept(this, context); + + // bind function + arithmetic = (TimestampArithmetic) arithmetic.withChildren(left, right); + String funcOpName; + if (arithmetic.getFuncName() == null) { + // e.g. YEARS_ADD, MONTHS_SUB + funcOpName = String.format("%sS_%s", arithmetic.getTimeUnit(), + (arithmetic.getOp() == Operator.ADD) ? "ADD" : "SUB"); + } else { + funcOpName = arithmetic.getFuncName(); + } + arithmetic = (TimestampArithmetic) arithmetic.withFuncName(funcOpName.toLowerCase(Locale.ROOT)); + + // type coercion + return visitImplicitCastInputTypes(arithmetic, arithmetic.expectedInputTypes()); + } + + /* ******************************************************************************************** + * type coercion + * ******************************************************************************************** */ + + @Override + public Expression visitIntegralDivide(IntegralDivide integralDivide, CascadesContext context) { + Expression left = integralDivide.left().accept(this, context); + Expression right = integralDivide.right().accept(this, context); + + // check before bind + integralDivide.checkLegalityBeforeTypeCoercion(); + + // type coercion + Expression newLeft = TypeCoercionUtils.castIfNotSameType(left, BigIntType.INSTANCE); + Expression newRight = TypeCoercionUtils.castIfNotSameType(right, BigIntType.INSTANCE); + return integralDivide.withChildren(newLeft, newRight); + } + + @Override + public Expression visitBinaryOperator(BinaryOperator binaryOperator, CascadesContext context) { + Expression left = binaryOperator.left().accept(this, context); + Expression right = binaryOperator.right().accept(this, context); + // check + binaryOperator.checkLegalityBeforeTypeCoercion(); + + // characterLiteralTypeCoercion + if (left instanceof Literal || right instanceof Literal) { + if (left instanceof Literal && ((Literal) left).isCharacterLiteral()) { + left = characterLiteralTypeCoercion(((Literal) left).getStringValue(), right.getDataType()) + .orElse(left); + } + if (right instanceof Literal && ((Literal) right).isCharacterLiteral()) { + right = characterLiteralTypeCoercion(((Literal) right).getStringValue(), left.getDataType()) + .orElse(right); + } + } + binaryOperator = (BinaryOperator) binaryOperator.withChildren(left, right); + + // type coercion + if (binaryOperator instanceof ImplicitCastInputTypes) { + List expectedInputTypes = ((ImplicitCastInputTypes) binaryOperator).expectedInputTypes(); + if (!expectedInputTypes.isEmpty()) { + binaryOperator.children().stream().filter(e -> e instanceof StringLikeLiteral) + .forEach(expr -> { + try { + new BigDecimal(((StringLikeLiteral) expr).getStringValue()); + } catch (NumberFormatException e) { + throw new IllegalStateException(String.format( + "string literal %s cannot be cast to double", expr.toSql())); + } + }); + binaryOperator = (BinaryOperator) visitImplicitCastInputTypes(binaryOperator, expectedInputTypes); + } + } + + BinaryOperator op = binaryOperator; + Expression opLeft = op.left(); + Expression opRight = op.right(); + + return Optional.of(TypeCoercionUtils.canHandleTypeCoercion(left.getDataType(), right.getDataType())) + .filter(Boolean::booleanValue) + .map(b -> TypeCoercionUtils.findTightestCommonType(op, opLeft.getDataType(), opRight.getDataType())) + .filter(Optional::isPresent) + .map(Optional::get) + .filter(ct -> op.inputType().acceptsType(ct)) + .filter(ct -> !opLeft.getDataType().equals(ct) || !opRight.getDataType().equals(ct)) + .map(commonType -> { + Expression newLeft = TypeCoercionUtils.castIfNotSameType(opLeft, commonType); + Expression newRight = TypeCoercionUtils.castIfNotSameType(opRight, commonType); + return op.withChildren(newLeft, newRight); + }) + .orElse(op.withChildren(opLeft, opRight)); + } + + @Override + public Expression visitDivide(Divide divide, CascadesContext context) { + Expression left = divide.left().accept(this, context); + Expression right = divide.right().accept(this, context); + divide = (Divide) divide.withChildren(left, right); + + // check + divide.checkLegalityBeforeTypeCoercion(); + + // type coercion + DataType t1 = TypeCoercionUtils.getNumResultType(left.getDataType()); + DataType t2 = TypeCoercionUtils.getNumResultType(right.getDataType()); + DataType commonType = TypeCoercionUtils.findCommonNumericsType(t1, t2); + + if (divide.getLegacyOperator() == Operator.DIVIDE + && (commonType.isBigIntType() || commonType.isLargeIntType())) { + commonType = DoubleType.INSTANCE; + } + Expression newLeft = TypeCoercionUtils.castIfNotSameType(left, commonType); + Expression newRight = TypeCoercionUtils.castIfNotSameType(right, commonType); + return divide.withChildren(newLeft, newRight); + } + + @Override + public Expression visitCaseWhen(CaseWhen caseWhen, CascadesContext context) { + List rewrittenChildren = caseWhen.children().stream() + .map(e -> e.accept(this, context)).collect(Collectors.toList()); + CaseWhen newCaseWhen = caseWhen.withChildren(rewrittenChildren); + + // check + newCaseWhen.checkLegalityBeforeTypeCoercion(); + + // type coercion + List dataTypesForCoercion = newCaseWhen.dataTypesForCoercion(); + if (dataTypesForCoercion.size() <= 1) { + return newCaseWhen; + } + DataType first = dataTypesForCoercion.get(0); + if (dataTypesForCoercion.stream().allMatch(dataType -> dataType.equals(first))) { + return newCaseWhen; + } + Optional optionalCommonType = TypeCoercionUtils.findWiderCommonType(dataTypesForCoercion); + return optionalCommonType + .map(commonType -> { + List newChildren + = newCaseWhen.getWhenClauses().stream() + .map(wc -> wc.withChildren(wc.getOperand(), + TypeCoercionUtils.castIfNotSameType(wc.getResult(), commonType))) + .collect(Collectors.toList()); + newCaseWhen.getDefaultValue() + .map(dv -> TypeCoercionUtils.castIfNotSameType(dv, commonType)) + .ifPresent(newChildren::add); + return newCaseWhen.withChildren(newChildren); + }) + .orElse(newCaseWhen); + } + + @Override + public Expression visitInPredicate(InPredicate inPredicate, CascadesContext context) { + List rewrittenChildren = inPredicate.children().stream() + .map(e -> e.accept(this, context)).collect(Collectors.toList()); + InPredicate newInPredicate = inPredicate.withChildren(rewrittenChildren); + + // check + newInPredicate.checkLegalityBeforeTypeCoercion(); + + // type coercion + if (newInPredicate.getOptions().stream().map(Expression::getDataType) + .allMatch(dt -> dt.equals(newInPredicate.getCompareExpr().getDataType()))) { + return newInPredicate; + } + Optional optionalCommonType = TypeCoercionUtils.findWiderCommonType(newInPredicate.children() + .stream().map(Expression::getDataType).collect(Collectors.toList())); + + return optionalCommonType + .map(commonType -> { + List newChildren = newInPredicate.children().stream() + .map(e -> TypeCoercionUtils.castIfNotSameType(e, commonType)) + .collect(Collectors.toList()); + return newInPredicate.withChildren(newChildren); + }) + .orElse(newInPredicate); + } + + @Override + public Expression visitBitNot(BitNot bitNot, CascadesContext context) { + Expression child = bitNot.child().accept(this, context); + // check + bitNot.checkLegalityBeforeTypeCoercion(); + + // type coercion + if (child.getDataType().toCatalogDataType().getPrimitiveType().ordinal() > PrimitiveType.LARGEINT.ordinal()) { + child = new Cast(child, BigIntType.INSTANCE); + } + return bitNot.withChildren(child); + } + + private Optional characterLiteralTypeCoercion(String value, DataType dataType) { + Expression ret = null; + try { + if (dataType instanceof BooleanType) { + if ("true".equalsIgnoreCase(value)) { + ret = BooleanLiteral.TRUE; + } + if ("false".equalsIgnoreCase(value)) { + ret = BooleanLiteral.FALSE; + } + } else if (dataType instanceof IntegralType) { + BigInteger bigInt = new BigInteger(value); + if (BigInteger.valueOf(bigInt.byteValue()).equals(bigInt)) { + ret = new TinyIntLiteral(bigInt.byteValue()); + } else if (BigInteger.valueOf(bigInt.shortValue()).equals(bigInt)) { + ret = new SmallIntLiteral(bigInt.shortValue()); + } else if (BigInteger.valueOf(bigInt.intValue()).equals(bigInt)) { + ret = new IntegerLiteral(bigInt.intValue()); + } else if (BigInteger.valueOf(bigInt.longValue()).equals(bigInt)) { + ret = new BigIntLiteral(bigInt.longValueExact()); + } else { + ret = new LargeIntLiteral(bigInt); + } + } else if (dataType instanceof FloatType) { + ret = new FloatLiteral(Float.parseFloat(value)); + } else if (dataType instanceof DoubleType) { + ret = new DoubleLiteral(Double.parseDouble(value)); + } else if (dataType instanceof DecimalV2Type) { + ret = new DecimalLiteral(new BigDecimal(value)); + } else if (dataType instanceof CharType) { + ret = new CharLiteral(value, value.length()); + } else if (dataType instanceof VarcharType) { + ret = new VarcharLiteral(value, value.length()); + } else if (dataType instanceof StringType) { + ret = new StringLiteral(value); + } else if (dataType instanceof DateType) { + ret = new DateLiteral(value); + } else if (dataType instanceof DateTimeType) { + ret = new DateTimeLiteral(value); + } + } catch (Exception e) { + // ignore + } + return Optional.ofNullable(ret); + } + + private Expression visitImplicitCastInputTypes(Expression expr, List expectedInputTypes) { + List> inputImplicitCastTypes + = getInputImplicitCastTypes(expr.children(), expectedInputTypes); + return castInputs(expr, inputImplicitCastTypes); + } + + private List> getInputImplicitCastTypes( + List inputs, List expectedTypes) { + Builder> implicitCastTypes = ImmutableList.builder(); + for (int i = 0; i < inputs.size(); i++) { + DataType argType = inputs.get(i).getDataType(); + AbstractDataType expectedType = expectedTypes.get(i); + Optional castType = TypeCoercionUtils.implicitCast(argType, expectedType); + // TODO: complete the cast logic like FunctionCallExpr.analyzeImpl + boolean legacyCastCompatible = expectedType instanceof DataType + && !(expectedType.getClass().equals(NumericType.class)) + && !(expectedType.getClass().equals(IntegralType.class)) + && !(expectedType.getClass().equals(FractionalType.class)) + && !(expectedType.getClass().equals(CharacterType.class)) + && !argType.toCatalogDataType().matchesType(expectedType.toCatalogDataType()); + if (!castType.isPresent() && legacyCastCompatible) { + castType = Optional.of((DataType) expectedType); + } + implicitCastTypes.add(castType); + } + return implicitCastTypes.build(); + } + + private Expression castInputs(Expression expr, List> castTypes) { + return expr.withChildren((child, childIndex) -> { + DataType argType = child.getDataType(); + Optional castType = castTypes.get(childIndex); + if (castType.isPresent() && !castType.get().equals(argType)) { + return TypeCoercionUtils.castIfNotSameType(child, castType.get()); + } else { + return child; + } + }); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/Binder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SlotBinder.java similarity index 96% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/Binder.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SlotBinder.java index b6f5be8568..f7bc29f90f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/Binder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SlotBinder.java @@ -29,7 +29,6 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.planner.PlannerContext; import org.apache.commons.lang.StringUtils; @@ -37,9 +36,9 @@ import java.util.List; import java.util.Optional; import java.util.stream.Collectors; -class Binder extends SubExprAnalyzer { +class SlotBinder extends SubExprAnalyzer { - public Binder(Scope scope, CascadesContext cascadesContext) { + public SlotBinder(Scope scope, CascadesContext cascadesContext) { super(scope, cascadesContext); } @@ -48,7 +47,7 @@ class Binder extends SubExprAnalyzer { } @Override - public Expression visitUnboundAlias(UnboundAlias unboundAlias, PlannerContext context) { + public Expression visitUnboundAlias(UnboundAlias unboundAlias, CascadesContext context) { Expression child = unboundAlias.child().accept(this, context); if (unboundAlias.getAlias().isPresent()) { return new Alias(child, unboundAlias.getAlias().get()); @@ -62,7 +61,7 @@ class Binder extends SubExprAnalyzer { } @Override - public Slot visitUnboundSlot(UnboundSlot unboundSlot, PlannerContext context) { + public Slot visitUnboundSlot(UnboundSlot unboundSlot, CascadesContext context) { Optional> boundedOpt = Optional.of(bindSlot(unboundSlot, getScope().getSlots())); boolean foundInThisScope = !boundedOpt.get().isEmpty(); // Currently only looking for symbols on the previous level. @@ -110,7 +109,7 @@ class Binder extends SubExprAnalyzer { } @Override - public Expression visitUnboundStar(UnboundStar unboundStar, PlannerContext context) { + public Expression visitUnboundStar(UnboundStar unboundStar, CascadesContext context) { List qualifier = unboundStar.getQualifier(); List slots = getScope().getSlots() .stream() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubExprAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubExprAnalyzer.java index a4b57fa851..7904c77e9d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubExprAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubExprAnalyzer.java @@ -33,7 +33,6 @@ import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewri import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; -import org.apache.doris.planner.PlannerContext; import com.google.common.collect.ImmutableList; @@ -45,7 +44,7 @@ import java.util.Optional; /** * Use the visitor to iterate sub expression. */ -public class SubExprAnalyzer extends DefaultExpressionRewriter { +class SubExprAnalyzer extends DefaultExpressionRewriter { private final Scope scope; private final CascadesContext cascadesContext; @@ -56,7 +55,7 @@ public class SubExprAnalyzer extends DefaultExpressionRewriter { } @Override - public Expression visitNot(Not not, PlannerContext context) { + public Expression visitNot(Not not, CascadesContext context) { Expression child = not.child(); if (child instanceof Exists) { return visitExistsSubquery( @@ -69,7 +68,7 @@ public class SubExprAnalyzer extends DefaultExpressionRewriter { } @Override - public Expression visitExistsSubquery(Exists exists, PlannerContext context) { + public Expression visitExistsSubquery(Exists exists, CascadesContext context) { AnalyzedResult analyzedResult = analyzeSubquery(exists); return new Exists(analyzedResult.getLogicalPlan(), @@ -77,7 +76,7 @@ public class SubExprAnalyzer extends DefaultExpressionRewriter { } @Override - public Expression visitInSubquery(InSubquery expr, PlannerContext context) { + public Expression visitInSubquery(InSubquery expr, CascadesContext context) { AnalyzedResult analyzedResult = analyzeSubquery(expr); checkOutputColumn(analyzedResult.getLogicalPlan()); @@ -90,7 +89,7 @@ public class SubExprAnalyzer extends DefaultExpressionRewriter { } @Override - public Expression visitScalarSubquery(ScalarSubquery scalar, PlannerContext context) { + public Expression visitScalarSubquery(ScalarSubquery scalar, CascadesContext context) { AnalyzedResult analyzedResult = analyzeSubquery(scalar); checkOutputColumn(analyzedResult.getLogicalPlan()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionNormalization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionNormalization.java index 575fbbf600..84fcdfb687 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionNormalization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionNormalization.java @@ -45,6 +45,7 @@ public class ExpressionNormalization extends ExpressionRewrite { BetweenToCompoundRule.INSTANCE, InPredicateToEqualToRule.INSTANCE, SimplifyNotExprRule.INSTANCE, + // TODO(morrySnow): remove type coercion from here after we could process subquery type coercion when bind CharacterLiteralTypeCoercion.INSTANCE, SimplifyArithmeticRule.INSTANCE, TypeCoercion.INSTANCE, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/CharacterLiteralTypeCoercion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/CharacterLiteralTypeCoercion.java index e314bfbd5c..9bcb0ec783 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/CharacterLiteralTypeCoercion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/CharacterLiteralTypeCoercion.java @@ -57,6 +57,7 @@ import java.util.Optional; /** * coercion character literal to another side */ +@Deprecated @DependsRules(CheckLegalityBeforeTypeCoercion.class) public class CharacterLiteralTypeCoercion extends AbstractExpressionRewriteRule { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/TypeCoercion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/TypeCoercion.java index 2cf38efb1b..8e4a793068 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/TypeCoercion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/TypeCoercion.java @@ -20,7 +20,6 @@ package org.apache.doris.nereids.rules.expression.rewrite.rules; import org.apache.doris.analysis.ArithmeticExpr.Operator; import org.apache.doris.catalog.PrimitiveType; import org.apache.doris.nereids.annotation.DependsRules; -import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.jobs.batch.CheckLegalityBeforeTypeCoercion; import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule; import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext; @@ -56,7 +55,7 @@ import java.util.stream.Collectors; * a rule to add implicit cast for expressions. * This class is inspired by spark's TypeCoercion. */ -@Developing +@Deprecated @DependsRules(CheckLegalityBeforeTypeCoercion.class) public class TypeCoercion extends AbstractExpressionRewriteRule { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSchemaScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSchemaScan.java index c373f66aa2..49dcb47175 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSchemaScan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSchemaScan.java @@ -21,8 +21,6 @@ import org.apache.doris.catalog.SchemaTable; import org.apache.doris.catalog.TableIf; import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.properties.LogicalProperties; -import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.RelationId; @@ -30,8 +28,6 @@ import org.apache.doris.nereids.trees.plans.algebra.Scan; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; import org.apache.doris.nereids.util.Utils; -import com.google.common.collect.ImmutableList; - import java.util.List; import java.util.Optional; @@ -59,14 +55,6 @@ public class LogicalSchemaScan extends LogicalRelation implements Scan { return visitor.visitLogicalSchemaScan(this, context); } - @Override - public List computeNonUserVisibleOutput() { - SchemaTable schemaTable = getTable(); - return schemaTable.getBaseSchema().stream() - .map(col -> SlotReference.fromColumn(col, qualified())) - .collect(ImmutableList.toImmutableList()); - } - @Override public Plan withGroupExpression(Optional groupExpression) { return new LogicalSchemaScan(id, table, qualifier, groupExpression, Optional.of(getLogicalProperties())); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java index 65fbe31049..3ecbd52a46 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java @@ -314,7 +314,7 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Patt ExceptionChecker.expectThrowsWithMsg( AnalysisException.class, "Aggregate functions in having clause can't be nested:" - + " sum(cast((cast(a1 as DOUBLE) + avg(a2)) as SMALLINT)).", + + " sum((cast(a1 as DOUBLE) + avg(a2))).", () -> PlanChecker.from(connectContext).analyze( "SELECT a1 FROM t1 GROUP BY a1 HAVING SUM(a1 + AVG(a2)) > 0" )); @@ -322,7 +322,7 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Patt ExceptionChecker.expectThrowsWithMsg( AnalysisException.class, "Aggregate functions in having clause can't be nested:" - + " sum(cast((cast((a1 + a2) as DOUBLE) + avg(a2)) as INT)).", + + " sum((cast((a1 + a2) as DOUBLE) + avg(a2))).", () -> PlanChecker.from(connectContext).analyze( "SELECT a1 FROM t1 GROUP BY a1 HAVING SUM(a1 + a2 + AVG(a2)) > 0" ));