[refactor] bind slot and function in one rule (#16288)
1. Bind slots and functions in one rule and apply type coercion there, fixing wrong result types and nullability. For example, SUM(a1 + AVG(a2)) where a1 and a2 are TINYINT: before, the return type was SMALLINT; after this PR it returns the correct type, DOUBLE.
2. Fix a runtime filter generator bug: runtime filters were bound to the wrong join conjuncts.
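A sketch of the type fix in item 1 (a1 and a2 stand for TINYINT columns; the names are only for illustration):

    SUM(a1 + AVG(a2))
    -- before this PR: the expression was typed as SMALLINT
    -- after this PR:  slot binding, function binding and type coercion run in one rule,
    --                 so the expression gets the correct type, DOUBLE

For item 2, the change in RuntimeFilterGenerator below replaces the shared running counter with the index of the hash join conjunct, so each runtime filter is attached to the conjunct it was built from.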
@@ -22,7 +22,6 @@ import org.apache.doris.nereids.jobs.batch.AdjustAggregateNullableForEmptySetJob
import org.apache.doris.nereids.jobs.batch.AnalyzeRulesJob;
import org.apache.doris.nereids.jobs.batch.AnalyzeSubqueryRulesJob;
import org.apache.doris.nereids.jobs.batch.CheckAnalysisJob;
import org.apache.doris.nereids.jobs.batch.TypeCoercionJob;

import java.util.Objects;
import java.util.Optional;
@@ -52,7 +51,6 @@ public class NereidsAnalyzer {
new AnalyzeRulesJob(cascadesContext, outerScope).execute();
new AnalyzeSubqueryRulesJob(cascadesContext).execute();
new AdjustAggregateNullableForEmptySetJob(cascadesContext).execute();
new TypeCoercionJob(cascadesContext).execute();
// check whether analyze result is meaningful
new CheckAnalysisJob(cascadesContext).execute();
}
@@ -20,9 +20,8 @@ package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount;
import org.apache.doris.nereids.rules.analysis.BindFunction;
import org.apache.doris.nereids.rules.analysis.BindExpression;
import org.apache.doris.nereids.rules.analysis.BindRelation;
import org.apache.doris.nereids.rules.analysis.BindSlotReference;
import org.apache.doris.nereids.rules.analysis.CheckPolicy;
import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots;
import org.apache.doris.nereids.rules.analysis.NormalizeRepeat;
@@ -32,9 +31,6 @@ import org.apache.doris.nereids.rules.analysis.RegisterCTE;
import org.apache.doris.nereids.rules.analysis.ReplaceExpressionByChildOutput;
import org.apache.doris.nereids.rules.analysis.ResolveOrdinalInOrderByAndGroupBy;
import org.apache.doris.nereids.rules.analysis.UserAuthentication;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization;
import org.apache.doris.nereids.rules.expression.rewrite.rules.CharacterLiteralTypeCoercion;
import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
import org.apache.doris.nereids.rules.rewrite.logical.HideOneRowRelationUnderUnion;

import com.google.common.collect.ImmutableList;
@@ -61,8 +57,7 @@ public class AnalyzeRulesJob extends BatchRulesJob {
new BindRelation(),
new CheckPolicy(),
new UserAuthentication(),
new BindSlotReference(scope),
new BindFunction(),
new BindExpression(scope),
new ProjectToGlobalAggregate(),
// this rule checks the logicalProject node's isDistinct property
// and replaces the logicalProject node with a LogicalAggregate node
@@ -73,9 +68,7 @@ public class AnalyzeRulesJob extends BatchRulesJob {
new AvgDistinctToSumDivCount(),
new ResolveOrdinalInOrderByAndGroupBy(),
new ReplaceExpressionByChildOutput(),
new HideOneRowRelationUnderUnion(),
new ExpressionNormalization(cascadesContext.getConnectContext(),
ImmutableList.of(CharacterLiteralTypeCoercion.INSTANCE, TypeCoercion.INSTANCE))
new HideOneRowRelationUnderUnion()
),
topDownBatch(
new FillUpMissingSlots(),
@@ -45,7 +45,6 @@ import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

/**
@@ -89,31 +88,27 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
List<TRuntimeFilterType> legalTypes = Arrays.stream(TRuntimeFilterType.values())
.filter(type -> (type.getValue() & ctx.getSessionVariable().getRuntimeFilterType()) > 0)
.collect(Collectors.toList());
AtomicInteger cnt = new AtomicInteger();
join.getHashJoinConjuncts().stream()
.map(EqualTo.class::cast)
// TODO: some complex situation cannot be handled now, see testPushDownThroughJoin.
// TODO: we will support it in later version.
.forEach(expr -> legalTypes.forEach(type -> {
Pair<Expression, Expression> normalizedChildren = checkAndMaybeSwapChild(expr, join);
// aliasTransMap doesn't contain the key, means that the path from the olap scan to the join
// contains join with denied join type. for example: a left join b on a.id = b.id
if (normalizedChildren == null
|| !aliasTransferMap.containsKey((Slot) normalizedChildren.first)) {
return;
}
Pair<Slot, Slot> slots = Pair.of(
aliasTransferMap.get((Slot) normalizedChildren.first).second.toSlot(),
((Slot) normalizedChildren.second));
RuntimeFilter filter = new RuntimeFilter(generator.getNextId(),
slots.second, slots.first, type,
cnt.getAndIncrement(), join);
ctx.addJoinToTargetMap(join, slots.first.getExprId());
ctx.setTargetExprIdToFilter(slots.first.getExprId(), filter);
ctx.setTargetsOnScanNode(
aliasTransferMap.get((Slot) normalizedChildren.first).first,
slots.first);
}));
// TODO: some complex situation cannot be handled now, see testPushDownThroughJoin.
// TODO: we will support it in later version.
for (int i = 0; i < join.getHashJoinConjuncts().size(); i++) {
EqualTo expr = (EqualTo) join.getHashJoinConjuncts().get(i);
for (TRuntimeFilterType type : legalTypes) {
Pair<Expression, Expression> normalizedChildren = checkAndMaybeSwapChild(expr, join);
// aliasTransMap doesn't contain the key, means that the path from the olap scan to the join
// contains join with denied join type. for example: a left join b on a.id = b.id
if (normalizedChildren == null || !aliasTransferMap.containsKey((Slot) normalizedChildren.first)) {
continue;
}
Pair<Slot, Slot> slots = Pair.of(
aliasTransferMap.get((Slot) normalizedChildren.first).second.toSlot(),
((Slot) normalizedChildren.second));
RuntimeFilter filter = new RuntimeFilter(generator.getNextId(),
slots.second, slots.first, type, i, join);
ctx.addJoinToTargetMap(join, slots.first.getExprId());
ctx.setTargetExprIdToFilter(slots.first.getExprId(), filter);
ctx.setTargetsOnScanNode(aliasTransferMap.get((Slot) normalizedChildren.first).first, slots.first);
}
}
}
return join;
}
@@ -17,17 +17,20 @@

package org.apache.doris.nereids.rules.analysis;

import org.apache.doris.catalog.BuiltinAggregateFunctions;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.analyzer.UnboundOneRowRelation;
import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.analyzer.UnboundTVFRelation;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.analysis.BindFunction.FunctionBinder;
import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
import org.apache.doris.nereids.trees.UnaryNode;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.BoundStar;
@@ -36,7 +39,13 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.TVFProperties;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
import org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
@@ -57,6 +66,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalTVFRelation;
import org.apache.doris.nereids.trees.plans.logical.UsingJoin;

import com.google.common.base.Preconditions;
@@ -81,11 +91,11 @@ import java.util.stream.Stream;
* BindSlotReference.
*/
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class BindSlotReference implements AnalysisRuleFactory {
public class BindExpression implements AnalysisRuleFactory {

private final Optional<Scope> outerScope;

public BindSlotReference(Optional<Scope> outerScope) {
public BindExpression(Optional<Scope> outerScope) {
this.outerScope = Objects.requireNonNull(outerScope, "outerScope cannot be null");
}
@@ -104,18 +114,23 @@ public class BindSlotReference implements AnalysisRuleFactory {
logicalProject().when(Plan::canBind).thenApply(ctx -> {
LogicalProject<GroupPlan> project = ctx.root;
List<NamedExpression> boundProjections =
bind(project.getProjects(), project.children(), ctx.cascadesContext);
List<NamedExpression> boundExceptions = bind(project.getExcepts(), project.children(),
bindSlot(project.getProjects(), project.children(), ctx.cascadesContext);
List<NamedExpression> boundExceptions = bindSlot(project.getExcepts(), project.children(),
ctx.cascadesContext);
boundProjections = flatBoundStar(boundProjections, boundExceptions);
boundProjections = boundProjections.stream()
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(ImmutableList.toImmutableList());
return new LogicalProject<>(boundProjections, project.child(), project.isDistinct());
})
),
RuleType.BINDING_FILTER_SLOT.build(
logicalFilter().when(Plan::canBind).thenApply(ctx -> {
LogicalFilter<GroupPlan> filter = ctx.root;
Set<Expression> boundConjuncts
= bind(filter.getConjuncts(), filter.children(), ctx.cascadesContext);
Set<Expression> boundConjuncts = filter.getConjuncts().stream()
.map(expr -> bindSlot(expr, filter.children(), ctx.cascadesContext))
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(Collectors.toSet());
return new LogicalFilter<>(boundConjuncts, filter.child());
})
),
@@ -141,7 +156,7 @@ public class BindSlotReference implements AnalysisRuleFactory {
.peek(s -> slotNames.add(s.getName()))
.collect(Collectors.toList()));
for (Expression unboundSlot : unboundSlots) {
Expression expression = new Binder(scope, ctx.cascadesContext).bind(unboundSlot);
Expression expression = new SlotBinder(scope, ctx.cascadesContext).bind(unboundSlot);
leftSlots.add(expression);
}
slotNames.clear();
@@ -151,7 +166,7 @@ public class BindSlotReference implements AnalysisRuleFactory {
.collect(Collectors.toList()));
List<Expression> rightSlots = new ArrayList<>();
for (Expression unboundSlot : unboundSlots) {
Expression expression = new Binder(scope, ctx.cascadesContext).bind(unboundSlot);
Expression expression = new SlotBinder(scope, ctx.cascadesContext).bind(unboundSlot);
rightSlots.add(expression);
}
int size = leftSlots.size();
@@ -166,10 +181,12 @@ public class BindSlotReference implements AnalysisRuleFactory {
logicalJoin().when(Plan::canBind).thenApply(ctx -> {
LogicalJoin<GroupPlan, GroupPlan> join = ctx.root;
List<Expression> cond = join.getOtherJoinConjuncts().stream()
.map(expr -> bind(expr, join.children(), ctx.cascadesContext))
.map(expr -> bindSlot(expr, join.children(), ctx.cascadesContext))
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(Collectors.toList());
List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts().stream()
.map(expr -> bind(expr, join.children(), ctx.cascadesContext))
.map(expr -> bindSlot(expr, join.children(), ctx.cascadesContext))
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(Collectors.toList());
return new LogicalJoin<>(join.getJoinType(),
hashJoinConjuncts, cond, join.getHint(), join.left(), join.right());
@@ -178,8 +195,10 @@ public class BindSlotReference implements AnalysisRuleFactory {
RuleType.BINDING_AGGREGATE_SLOT.build(
logicalAggregate().when(Plan::canBind).thenApply(ctx -> {
LogicalAggregate<GroupPlan> agg = ctx.root;
List<NamedExpression> output =
bind(agg.getOutputExpressions(), agg.children(), ctx.cascadesContext);
List<NamedExpression> output = agg.getOutputExpressions().stream()
.map(expr -> bindSlot(expr, agg.children(), ctx.cascadesContext))
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(ImmutableList.toImmutableList());

// The columns referenced in group by are first obtained from the child's output,
// and then from the node's output
@@ -235,17 +254,8 @@ public class BindSlotReference implements AnalysisRuleFactory {
.filter(ne -> ne instanceof Alias)
.map(Alias.class::cast)
// agg function cannot be bound with group_by_key
// TODO(morrySnow): after bind function moved here,
// we could just use instanceof AggregateFunction
.filter(alias -> !alias.child().anyMatch(expr -> {
if (expr instanceof UnboundFunction) {
UnboundFunction unboundFunction = (UnboundFunction) expr;
return BuiltinAggregateFunctions.INSTANCE.aggFuncNames.contains(
unboundFunction.getName().toLowerCase());
}
return false;
}
))
.filter(alias -> !alias.child()
.anyMatch(expr -> expr instanceof AggregateFunction))
.forEach(alias -> childOutputsToExpr.putIfAbsent(alias.getName(), alias.child()));

List<Expression> replacedGroupBy = agg.getGroupByExpressions().stream()
@@ -280,9 +290,10 @@ public class BindSlotReference implements AnalysisRuleFactory {
.collect(Collectors.toSet());

boundSlots.addAll(outputSlots);
Binder binder = new Binder(toScope(Lists.newArrayList(boundSlots)), ctx.cascadesContext);
Binder childBinder
= new Binder(toScope(new ArrayList<>(agg.child().getOutputSet())), ctx.cascadesContext);
SlotBinder binder = new SlotBinder(
toScope(Lists.newArrayList(boundSlots)), ctx.cascadesContext);
SlotBinder childBinder = new SlotBinder(
toScope(new ArrayList<>(agg.child().getOutputSet())), ctx.cascadesContext);

List<Expression> groupBy = replacedGroupBy.stream()
.map(expression -> {
@@ -305,15 +316,19 @@ public class BindSlotReference implements AnalysisRuleFactory {
if (hasUnBound.test(groupBy)) {
throw new AnalysisException("cannot bind GROUP BY KEY: " + unboundGroupBys.get(0).toSql());
}
groupBy = groupBy.stream()
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(ImmutableList.toImmutableList());
return agg.withGroupByAndOutput(groupBy, output);
})
),
RuleType.BINDING_REPEAT_SLOT.build(
logicalRepeat().when(Plan::canBind).thenApply(ctx -> {
LogicalRepeat<GroupPlan> repeat = ctx.root;

List<NamedExpression> output =
bind(repeat.getOutputExpressions(), repeat.children(), ctx.cascadesContext);
List<NamedExpression> output = repeat.getOutputExpressions().stream()
.map(expr -> bindSlot(expr, repeat.children(), ctx.cascadesContext))
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(ImmutableList.toImmutableList());

// The columns referenced in group by are first obtained from the child's output,
// and then from the node's output
@@ -343,7 +358,10 @@ public class BindSlotReference implements AnalysisRuleFactory {

List<List<Expression>> groupingSets = replacedGroupingSets
.stream()
.map(groupingSet -> bind(groupingSet, repeat.children(), ctx.cascadesContext))
.map(groupingSet -> groupingSet.stream()
.map(expr -> bindSlot(expr, repeat.children(), ctx.cascadesContext))
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(ImmutableList.toImmutableList()))
.collect(ImmutableList.toImmutableList());
List<NamedExpression> newOutput = adjustNullableForRepeat(groupingSets, output);
return repeat.withGroupSetsAndOutput(groupingSets, newOutput);
@@ -383,7 +401,8 @@ public class BindSlotReference implements AnalysisRuleFactory {
List<OrderKey> sortItemList = sort.getOrderKeys()
.stream()
.map(orderKey -> {
Expression item = bind(orderKey.getExpr(), sort.child(), ctx.cascadesContext);
Expression item = bindSlot(orderKey.getExpr(), sort.child(), ctx.cascadesContext);
item = bindFunction(item, ctx.cascadesContext);
return new OrderKey(item, orderKey.isAsc(), orderKey.isNullFirst());
}).collect(Collectors.toList());
return new LogicalSort<>(sortItemList, sort.child());
@@ -393,11 +412,12 @@ public class BindSlotReference implements AnalysisRuleFactory {
logicalHaving(aggregate()).when(Plan::canBind).thenApply(ctx -> {
LogicalHaving<Aggregate<GroupPlan>> having = ctx.root;
Plan childPlan = having.child();
Set<Expression> boundConjuncts = having.getConjuncts().stream().map(
expr -> {
expr = bind(expr, childPlan.children(), ctx.cascadesContext);
return bind(expr, childPlan, ctx.cascadesContext);
Set<Expression> boundConjuncts = having.getConjuncts().stream()
.map(expr -> {
expr = bindSlot(expr, childPlan.children(), ctx.cascadesContext);
return bindSlot(expr, childPlan, ctx.cascadesContext);
})
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(Collectors.toSet());
return new LogicalHaving<>(boundConjuncts, having.child());
})
@@ -406,11 +426,12 @@ public class BindSlotReference implements AnalysisRuleFactory {
logicalHaving(any()).when(Plan::canBind).thenApply(ctx -> {
LogicalHaving<Plan> having = ctx.root;
Plan childPlan = having.child();
Set<Expression> boundConjuncts = having.getConjuncts().stream().map(
expr -> {
expr = bind(expr, childPlan, ctx.cascadesContext);
return bind(expr, childPlan.children(), ctx.cascadesContext);
})
Set<Expression> boundConjuncts = having.getConjuncts().stream()
.map(expr -> {
expr = bindSlot(expr, childPlan, ctx.cascadesContext);
return bindSlot(expr, childPlan.children(), ctx.cascadesContext);
})
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(Collectors.toSet());
return new LogicalHaving<>(boundConjuncts, having.child());
})
@@ -421,7 +442,8 @@ public class BindSlotReference implements AnalysisRuleFactory {
UnboundOneRowRelation oneRowRelation = ctx.root;
List<NamedExpression> projects = oneRowRelation.getProjects()
.stream()
.map(project -> bind(project, ImmutableList.of(), ctx.cascadesContext))
.map(project -> bindSlot(project, ImmutableList.of(), ctx.cascadesContext))
.map(project -> bindFunction(project, ctx.cascadesContext))
.collect(Collectors.toList());
return new LogicalOneRowRelation(projects);
})
@@ -450,26 +472,31 @@ public class BindSlotReference implements AnalysisRuleFactory {
})
),
RuleType.BINDING_GENERATE_SLOT.build(
logicalGenerate().when(Plan::canBind).thenApply(ctx -> {
LogicalGenerate<GroupPlan> generate = ctx.root;
List<Function> boundSlotGenerators
= bind(generate.getGenerators(), generate.children(), ctx.cascadesContext);
List<Function> boundFunctionGenerators = boundSlotGenerators.stream()
.map(f -> FunctionBinder.INSTANCE.bindTableGeneratingFunction(
(UnboundFunction) f, ctx.statementContext))
.collect(Collectors.toList());
ImmutableList.Builder<Slot> slotBuilder = ImmutableList.builder();
for (int i = 0; i < generate.getGeneratorOutput().size(); i++) {
Function generator = boundFunctionGenerators.get(i);
UnboundSlot slot = (UnboundSlot) generate.getGeneratorOutput().get(i);
Preconditions.checkState(slot.getNameParts().size() == 2,
"the size of nameParts of UnboundSlot in LogicalGenerate must be 2.");
Slot boundSlot = new SlotReference(slot.getNameParts().get(1), generator.getDataType(),
generator.nullable(), ImmutableList.of(slot.getNameParts().get(0)));
slotBuilder.add(boundSlot);
}
return new LogicalGenerate<>(boundFunctionGenerators, slotBuilder.build(), generate.child());
})
logicalGenerate().when(Plan::canBind).thenApply(ctx -> {
LogicalGenerate<GroupPlan> generate = ctx.root;
List<Function> boundSlotGenerators
= bindSlot(generate.getGenerators(), generate.children(), ctx.cascadesContext);
List<Function> boundFunctionGenerators = boundSlotGenerators.stream()
.map(f -> bindTableGeneratingFunction((UnboundFunction) f, ctx.cascadesContext))
.collect(Collectors.toList());
ImmutableList.Builder<Slot> slotBuilder = ImmutableList.builder();
for (int i = 0; i < generate.getGeneratorOutput().size(); i++) {
Function generator = boundFunctionGenerators.get(i);
UnboundSlot slot = (UnboundSlot) generate.getGeneratorOutput().get(i);
Preconditions.checkState(slot.getNameParts().size() == 2,
"the size of nameParts of UnboundSlot in LogicalGenerate must be 2.");
Slot boundSlot = new SlotReference(slot.getNameParts().get(1), generator.getDataType(),
generator.nullable(), ImmutableList.of(slot.getNameParts().get(0)));
slotBuilder.add(boundSlot);
}
return new LogicalGenerate<>(boundFunctionGenerators, slotBuilder.build(), generate.child());
})
),
RuleType.BINDING_UNBOUND_TVF_RELATION_FUNCTION.build(
unboundTVFRelation().thenApply(ctx -> {
UnboundTVFRelation relation = ctx.root;
return bindTableValuedFunction(relation, ctx.statementContext);
})
),

// when child update, we need update current plan's logical properties,
@@ -502,8 +529,9 @@ public class BindSlotReference implements AnalysisRuleFactory {
List<OrderKey> sortItemList = sort.getOrderKeys()
.stream()
.map(orderKey -> {
Expression item = bind(orderKey.getExpr(), plan, ctx);
item = bind(item, plan.children(), ctx);
Expression item = bindSlot(orderKey.getExpr(), plan, ctx);
item = bindSlot(item, plan.children(), ctx);
item = bindFunction(item, ctx);
return new OrderKey(item, orderKey.isAsc(), orderKey.isNullFirst());
}).collect(Collectors.toList());
return new LogicalSort<>(sortItemList, sort.child());
@@ -525,29 +553,36 @@ public class BindSlotReference implements AnalysisRuleFactory {
.collect(ImmutableList.toImmutableList());
}

private <E extends Expression> List<E> bind(List<E> exprList, List<Plan> inputs, CascadesContext cascadesContext) {
private <E extends Expression> List<E> bindSlot(
List<E> exprList, List<Plan> inputs, CascadesContext cascadesContext) {
return exprList.stream()
.map(expr -> bind(expr, inputs, cascadesContext))
.map(expr -> bindSlot(expr, inputs, cascadesContext))
.collect(Collectors.toList());
}

private <E extends Expression> Set<E> bind(Set<E> exprList, List<Plan> inputs, CascadesContext cascadesContext) {
private <E extends Expression> Set<E> bindSlot(
Set<E> exprList, List<Plan> inputs, CascadesContext cascadesContext) {
return exprList.stream()
.map(expr -> bind(expr, inputs, cascadesContext))
.map(expr -> bindSlot(expr, inputs, cascadesContext))
.collect(Collectors.toSet());
}

@SuppressWarnings("unchecked")
private <E extends Expression> E bind(E expr, Plan input, CascadesContext cascadesContext) {
return (E) new Binder(toScope(input.getOutput()), cascadesContext).bind(expr);
private <E extends Expression> E bindSlot(E expr, Plan input, CascadesContext cascadesContext) {
return (E) new SlotBinder(toScope(input.getOutput()), cascadesContext).bind(expr);
}

@SuppressWarnings("unchecked")
private <E extends Expression> E bind(E expr, List<Plan> inputs, CascadesContext cascadesContext) {
private <E extends Expression> E bindSlot(E expr, List<Plan> inputs, CascadesContext cascadesContext) {
List<Slot> boundedSlots = inputs.stream()
.flatMap(input -> input.getOutput().stream())
.collect(Collectors.toList());
return (E) new Binder(toScope(boundedSlots), cascadesContext).bind(expr);
return (E) new SlotBinder(toScope(boundedSlots), cascadesContext).bind(expr);
}

@SuppressWarnings("unchecked")
private <E extends Expression> E bindFunction(E expr, CascadesContext cascadesContext) {
return (E) FunctionBinder.INSTANCE.bind(expr, cascadesContext);
}

/**
@@ -578,4 +613,36 @@ public class BindSlotReference implements AnalysisRuleFactory {
return slotReference;
}
}

private LogicalTVFRelation bindTableValuedFunction(UnboundTVFRelation unboundTVFRelation,
StatementContext statementContext) {
Env env = statementContext.getConnectContext().getEnv();
FunctionRegistry functionRegistry = env.getFunctionRegistry();

String functionName = unboundTVFRelation.getFunctionName();
TVFProperties arguments = unboundTVFRelation.getProperties();
FunctionBuilder functionBuilder = functionRegistry.findFunctionBuilder(functionName, arguments);
BoundFunction function = functionBuilder.build(functionName, arguments);
if (!(function instanceof TableValuedFunction)) {
throw new AnalysisException(function.toSql() + " is not a TableValuedFunction");
}
return new LogicalTVFRelation(unboundTVFRelation.getId(), (TableValuedFunction) function);
}

private BoundFunction bindTableGeneratingFunction(UnboundFunction unboundFunction,
CascadesContext cascadesContext) {
List<Expression> boundArguments = unboundFunction.getArguments().stream()
.map(e -> bindFunction(e, cascadesContext))
.collect(Collectors.toList());
FunctionRegistry functionRegistry = cascadesContext.getConnectContext().getEnv().getFunctionRegistry();

String functionName = unboundFunction.getName();
FunctionBuilder functionBuilder = functionRegistry.findFunctionBuilder(functionName, boundArguments);
BoundFunction function = functionBuilder.build(functionName, boundArguments);
if (!(function instanceof TableGeneratingFunction)) {
throw new AnalysisException(function.toSql() + " is not a TableGeneratingFunction");
}
function = (BoundFunction) TypeCoercion.INSTANCE.rewrite(function, null);
return function;
}
}
@@ -1,269 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.analysis;

import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.analyzer.UnboundTVFRelation;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.jobs.batch.CheckLegalityBeforeTypeCoercion;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rewrite.rules.CharacterLiteralTypeCoercion;
import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.TVFProperties;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
import org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalTVFRelation;
import org.apache.doris.qe.ConnectContext;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import java.util.List;
import java.util.Locale;
import java.util.Set;
/**
 * BindFunction.
 */
public class BindFunction implements AnalysisRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
RuleType.BINDING_ONE_ROW_RELATION_FUNCTION.build(
logicalOneRowRelation().thenApply(ctx -> {
LogicalOneRowRelation oneRowRelation = ctx.root;
List<NamedExpression> projects = oneRowRelation.getProjects();
List<NamedExpression> boundProjects = bindAndTypeCoercion(projects, ctx.connectContext);
if (projects.equals(boundProjects)) {
return oneRowRelation;
}
return new LogicalOneRowRelation(boundProjects);
})
),
RuleType.BINDING_PROJECT_FUNCTION.build(
logicalProject().thenApply(ctx -> {
LogicalProject<GroupPlan> project = ctx.root;
List<NamedExpression> boundExpr = bindAndTypeCoercion(project.getProjects(),
ctx.connectContext);
return new LogicalProject<>(boundExpr, project.child(), project.isDistinct());
})
),
RuleType.BINDING_AGGREGATE_FUNCTION.build(
logicalAggregate().thenApply(ctx -> {
LogicalAggregate<GroupPlan> agg = ctx.root;
List<Expression> groupBy = bindAndTypeCoercion(agg.getGroupByExpressions(),
ctx.connectContext);
List<NamedExpression> output = bindAndTypeCoercion(agg.getOutputExpressions(),
ctx.connectContext);
return agg.withGroupByAndOutput(groupBy, output);
})
),
RuleType.BINDING_REPEAT_FUNCTION.build(
logicalRepeat().thenApply(ctx -> {
LogicalRepeat<GroupPlan> repeat = ctx.root;
List<List<Expression>> groupingSets = repeat.getGroupingSets()
.stream()
.map(groupingSet -> bindAndTypeCoercion(groupingSet, ctx.connectContext))
.collect(ImmutableList.toImmutableList());
List<NamedExpression> output = bindAndTypeCoercion(repeat.getOutputExpressions(),
ctx.connectContext);
return repeat.withGroupSetsAndOutput(groupingSets, output);
})
),
RuleType.BINDING_FILTER_FUNCTION.build(
logicalFilter().thenApply(ctx -> {
LogicalFilter<GroupPlan> filter = ctx.root;
Set<Expression> conjuncts = bindAndTypeCoercion(filter.getConjuncts(), ctx.connectContext);
return new LogicalFilter<>(conjuncts, filter.child());
})
),
RuleType.BINDING_HAVING_FUNCTION.build(
logicalHaving().thenApply(ctx -> {
LogicalHaving<GroupPlan> having = ctx.root;
Set<Expression> conjuncts = bindAndTypeCoercion(having.getConjuncts(), ctx.connectContext);
return new LogicalHaving<>(conjuncts, having.child());
})
),
RuleType.BINDING_SORT_FUNCTION.build(
logicalSort().thenApply(ctx -> {
LogicalSort<GroupPlan> sort = ctx.root;
List<OrderKey> orderKeys = sort.getOrderKeys().stream()
.map(orderKey -> new OrderKey(
bindAndTypeCoercion(orderKey.getExpr(),
ctx.connectContext.getEnv(),
new ExpressionRewriteContext(ctx.connectContext)
),
orderKey.isAsc(),
orderKey.isNullFirst())

)
.collect(ImmutableList.toImmutableList());
return new LogicalSort<>(orderKeys, sort.child());
})
),
RuleType.BINDING_JOIN_FUNCTION.build(
logicalJoin().thenApply(ctx -> {
LogicalJoin<GroupPlan, GroupPlan> join = ctx.root;
List<Expression> hashConjuncts = bindAndTypeCoercion(join.getHashJoinConjuncts(),
ctx.connectContext);
List<Expression> otherConjuncts = bindAndTypeCoercion(join.getOtherJoinConjuncts(),
ctx.connectContext);
return new LogicalJoin<>(join.getJoinType(), hashConjuncts, otherConjuncts,
join.getHint(),
join.left(), join.right());
})
),
RuleType.BINDING_UNBOUND_TVF_RELATION_FUNCTION.build(
unboundTVFRelation().thenApply(ctx -> {
UnboundTVFRelation relation = ctx.root;
return FunctionBinder.INSTANCE.bindTableValuedFunction(relation, ctx.statementContext);
})
)
);
}

private <E extends Expression> List<E> bindAndTypeCoercion(List<? extends E> exprList, ConnectContext ctx) {
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx);
return exprList.stream()
.map(expr -> bindAndTypeCoercion(expr, ctx.getEnv(), rewriteContext))
.collect(ImmutableList.toImmutableList());
}

private <E extends Expression> E bindAndTypeCoercion(E expr, Env env, ExpressionRewriteContext ctx) {
expr = FunctionBinder.INSTANCE.bind(expr, env);
expr = (E) expr.accept(CheckLegalityBeforeTypeCoercion.INSTANCE, ctx);
expr = (E) CharacterLiteralTypeCoercion.INSTANCE.rewrite(expr, ctx);
return (E) TypeCoercion.INSTANCE.rewrite(expr, null);
}

private <E extends Expression> Set<E> bindAndTypeCoercion(Set<? extends E> exprSet, ConnectContext ctx) {
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx);
return exprSet.stream()
.map(expr -> bindAndTypeCoercion(expr, ctx.getEnv(), rewriteContext))
.collect(ImmutableSet.toImmutableSet());
}

/**
 * function binder
 */
public static class FunctionBinder extends DefaultExpressionRewriter<Env> {
public static final FunctionBinder INSTANCE = new FunctionBinder();

public <E extends Expression> E bind(E expression, Env env) {
return (E) expression.accept(this, env);
}

/**
 * bindTableValuedFunction
 */
public LogicalTVFRelation bindTableValuedFunction(UnboundTVFRelation unboundTVFRelation,
StatementContext statementContext) {
Env env = statementContext.getConnectContext().getEnv();
FunctionRegistry functionRegistry = env.getFunctionRegistry();

String functionName = unboundTVFRelation.getFunctionName();
TVFProperties arguments = unboundTVFRelation.getProperties();
FunctionBuilder functionBuilder = functionRegistry.findFunctionBuilder(functionName, arguments);
BoundFunction function = functionBuilder.build(functionName, arguments);
if (!(function instanceof TableValuedFunction)) {
throw new AnalysisException(function.toSql() + " is not a TableValuedFunction");
}
return new LogicalTVFRelation(unboundTVFRelation.getId(), (TableValuedFunction) function);
}

/**
 * bindTableGeneratingFunction
 */
public BoundFunction bindTableGeneratingFunction(UnboundFunction unboundFunction,
StatementContext statementContext) {
Env env = statementContext.getConnectContext().getEnv();
List<Expression> boundArguments = unboundFunction.getArguments().stream()
.map(e -> INSTANCE.bind(e, env))
.collect(ImmutableList.toImmutableList());
FunctionRegistry functionRegistry = env.getFunctionRegistry();

String functionName = unboundFunction.getName();
FunctionBuilder functionBuilder = functionRegistry.findFunctionBuilder(functionName, boundArguments);
BoundFunction function = functionBuilder.build(functionName, boundArguments);
if (!(function instanceof TableGeneratingFunction)) {
throw new AnalysisException(function.toSql() + " is not a TableGeneratingFunction");
}

return function;
}

@Override
public Expression visitUnboundFunction(UnboundFunction unboundFunction, Env env) {
unboundFunction = (UnboundFunction) super.visitUnboundFunction(unboundFunction, env);

FunctionRegistry functionRegistry = env.getFunctionRegistry();
String functionName = unboundFunction.getName();
List<Object> arguments = unboundFunction.isDistinct()
? ImmutableList.builder()
.add(unboundFunction.isDistinct())
.addAll(unboundFunction.getArguments())
.build()
: (List) unboundFunction.getArguments();

FunctionBuilder builder = functionRegistry.findFunctionBuilder(functionName, arguments);
BoundFunction boundFunction = builder.build(functionName, arguments);
return boundFunction;
}

/**
 * gets the method for calculating the time.
 * e.g. YEARS_ADD、YEARS_SUB、DAYS_ADD 、DAYS_SUB
 */
@Override
public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, Env context) {
arithmetic = (TimestampArithmetic) super.visitTimestampArithmetic(arithmetic, context);

String funcOpName;
if (arithmetic.getFuncName() == null) {
// e.g. YEARS_ADD, MONTHS_SUB
funcOpName = String.format("%sS_%s", arithmetic.getTimeUnit(),
(arithmetic.getOp() == Operator.ADD) ? "ADD" : "SUB");
} else {
funcOpName = arithmetic.getFuncName();
}
return arithmetic.withFuncName(funcOpName.toLowerCase(Locale.ROOT));
}
}
}
@@ -0,0 +1,415 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.analysis;

import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
import org.apache.doris.nereids.trees.expressions.BitNot;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IntegralDivide;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.CharLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.LargeIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.CharType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.types.coercion.FractionalType;
import org.apache.doris.nereids.types.coercion.IntegralType;
import org.apache.doris.nereids.types.coercion.NumericType;
import org.apache.doris.nereids.util.TypeCoercionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.stream.Collectors;
/**
 * function binder
 */
class FunctionBinder extends DefaultExpressionRewriter<CascadesContext> {
public static final FunctionBinder INSTANCE = new FunctionBinder();

public <E extends Expression> E bind(E expression, CascadesContext context) {
return (E) expression.accept(this, context);
}

@Override
public Expression visit(Expression expr, CascadesContext context) {
expr = super.visit(expr, context);
expr.checkLegalityBeforeTypeCoercion();
if (expr instanceof ImplicitCastInputTypes) {
List<AbstractDataType> expectedInputTypes = ((ImplicitCastInputTypes) expr).expectedInputTypes();
if (!expectedInputTypes.isEmpty()) {
return visitImplicitCastInputTypes(expr, expectedInputTypes);
}
}
return expr;
}

/* ********************************************************************************************
 * bind function
 * ******************************************************************************************** */

@Override
public Expression visitUnboundFunction(UnboundFunction unboundFunction, CascadesContext context) {
unboundFunction = unboundFunction.withChildren(unboundFunction.children().stream()
.map(e -> e.accept(this, context)).collect(Collectors.toList()));

// bind function
FunctionRegistry functionRegistry = context.getConnectContext().getEnv().getFunctionRegistry();
String functionName = unboundFunction.getName();
List<Object> arguments = unboundFunction.isDistinct()
? ImmutableList.builder()
.add(unboundFunction.isDistinct())
.addAll(unboundFunction.getArguments())
.build()
: (List) unboundFunction.getArguments();

FunctionBuilder builder = functionRegistry.findFunctionBuilder(functionName, arguments);
BoundFunction boundFunction = builder.build(functionName, arguments);

// check
boundFunction.checkLegalityBeforeTypeCoercion();

// type coercion
return visitImplicitCastInputTypes(boundFunction, boundFunction.expectedInputTypes());
}

/**
 * gets the method for calculating the time.
 * e.g. YEARS_ADD、YEARS_SUB、DAYS_ADD 、DAYS_SUB
 */
@Override
public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, CascadesContext context) {
Expression left = arithmetic.left().accept(this, context);
Expression right = arithmetic.right().accept(this, context);

// bind function
arithmetic = (TimestampArithmetic) arithmetic.withChildren(left, right);
String funcOpName;
if (arithmetic.getFuncName() == null) {
// e.g. YEARS_ADD, MONTHS_SUB
funcOpName = String.format("%sS_%s", arithmetic.getTimeUnit(),
(arithmetic.getOp() == Operator.ADD) ? "ADD" : "SUB");
} else {
funcOpName = arithmetic.getFuncName();
}
arithmetic = (TimestampArithmetic) arithmetic.withFuncName(funcOpName.toLowerCase(Locale.ROOT));

// type coercion
return visitImplicitCastInputTypes(arithmetic, arithmetic.expectedInputTypes());
}

/* ********************************************************************************************
 * type coercion
 * ******************************************************************************************** */

@Override
public Expression visitIntegralDivide(IntegralDivide integralDivide, CascadesContext context) {
Expression left = integralDivide.left().accept(this, context);
Expression right = integralDivide.right().accept(this, context);

// check before bind
integralDivide.checkLegalityBeforeTypeCoercion();

// type coercion
Expression newLeft = TypeCoercionUtils.castIfNotSameType(left, BigIntType.INSTANCE);
Expression newRight = TypeCoercionUtils.castIfNotSameType(right, BigIntType.INSTANCE);
return integralDivide.withChildren(newLeft, newRight);
}

@Override
public Expression visitBinaryOperator(BinaryOperator binaryOperator, CascadesContext context) {
Expression left = binaryOperator.left().accept(this, context);
Expression right = binaryOperator.right().accept(this, context);
// check
binaryOperator.checkLegalityBeforeTypeCoercion();

// characterLiteralTypeCoercion
if (left instanceof Literal || right instanceof Literal) {
if (left instanceof Literal && ((Literal) left).isCharacterLiteral()) {
left = characterLiteralTypeCoercion(((Literal) left).getStringValue(), right.getDataType())
.orElse(left);
}
if (right instanceof Literal && ((Literal) right).isCharacterLiteral()) {
right = characterLiteralTypeCoercion(((Literal) right).getStringValue(), left.getDataType())
.orElse(right);
}
}
binaryOperator = (BinaryOperator) binaryOperator.withChildren(left, right);

// type coercion
if (binaryOperator instanceof ImplicitCastInputTypes) {
List<AbstractDataType> expectedInputTypes = ((ImplicitCastInputTypes) binaryOperator).expectedInputTypes();
if (!expectedInputTypes.isEmpty()) {
binaryOperator.children().stream().filter(e -> e instanceof StringLikeLiteral)
.forEach(expr -> {
try {
new BigDecimal(((StringLikeLiteral) expr).getStringValue());
} catch (NumberFormatException e) {
throw new IllegalStateException(String.format(
"string literal %s cannot be cast to double", expr.toSql()));
}
});
binaryOperator = (BinaryOperator) visitImplicitCastInputTypes(binaryOperator, expectedInputTypes);
}
}

BinaryOperator op = binaryOperator;
Expression opLeft = op.left();
Expression opRight = op.right();

return Optional.of(TypeCoercionUtils.canHandleTypeCoercion(left.getDataType(), right.getDataType()))
.filter(Boolean::booleanValue)
.map(b -> TypeCoercionUtils.findTightestCommonType(op, opLeft.getDataType(), opRight.getDataType()))
.filter(Optional::isPresent)
.map(Optional::get)
.filter(ct -> op.inputType().acceptsType(ct))
.filter(ct -> !opLeft.getDataType().equals(ct) || !opRight.getDataType().equals(ct))
.map(commonType -> {
Expression newLeft = TypeCoercionUtils.castIfNotSameType(opLeft, commonType);
Expression newRight = TypeCoercionUtils.castIfNotSameType(opRight, commonType);
return op.withChildren(newLeft, newRight);
})
.orElse(op.withChildren(opLeft, opRight));
}
@Override
|
||||
public Expression visitDivide(Divide divide, CascadesContext context) {
|
||||
Expression left = divide.left().accept(this, context);
|
||||
Expression right = divide.right().accept(this, context);
|
||||
divide = (Divide) divide.withChildren(left, right);
|
||||
|
||||
// check
|
||||
divide.checkLegalityBeforeTypeCoercion();
|
||||
|
||||
// type coercion
|
||||
DataType t1 = TypeCoercionUtils.getNumResultType(left.getDataType());
|
||||
DataType t2 = TypeCoercionUtils.getNumResultType(right.getDataType());
|
||||
DataType commonType = TypeCoercionUtils.findCommonNumericsType(t1, t2);
|
||||
|
||||
if (divide.getLegacyOperator() == Operator.DIVIDE
|
||||
&& (commonType.isBigIntType() || commonType.isLargeIntType())) {
|
||||
commonType = DoubleType.INSTANCE;
|
||||
}
|
||||
Expression newLeft = TypeCoercionUtils.castIfNotSameType(left, commonType);
|
||||
Expression newRight = TypeCoercionUtils.castIfNotSameType(right, commonType);
|
||||
return divide.withChildren(newLeft, newRight);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitCaseWhen(CaseWhen caseWhen, CascadesContext context) {
|
||||
List<Expression> rewrittenChildren = caseWhen.children().stream()
|
||||
.map(e -> e.accept(this, context)).collect(Collectors.toList());
|
||||
CaseWhen newCaseWhen = caseWhen.withChildren(rewrittenChildren);
|
||||
|
||||
// check
|
||||
newCaseWhen.checkLegalityBeforeTypeCoercion();
|
||||
|
||||
// type coercion
|
||||
List<DataType> dataTypesForCoercion = newCaseWhen.dataTypesForCoercion();
|
||||
if (dataTypesForCoercion.size() <= 1) {
|
||||
return newCaseWhen;
|
||||
}
|
||||
DataType first = dataTypesForCoercion.get(0);
|
||||
if (dataTypesForCoercion.stream().allMatch(dataType -> dataType.equals(first))) {
|
||||
return newCaseWhen;
|
||||
}
|
||||
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(dataTypesForCoercion);
|
||||
return optionalCommonType
|
||||
.map(commonType -> {
|
||||
List<Expression> newChildren
|
||||
= newCaseWhen.getWhenClauses().stream()
|
||||
.map(wc -> wc.withChildren(wc.getOperand(),
|
||||
TypeCoercionUtils.castIfNotSameType(wc.getResult(), commonType)))
|
||||
.collect(Collectors.toList());
|
||||
newCaseWhen.getDefaultValue()
|
||||
.map(dv -> TypeCoercionUtils.castIfNotSameType(dv, commonType))
|
||||
.ifPresent(newChildren::add);
|
||||
return newCaseWhen.withChildren(newChildren);
|
||||
})
|
||||
.orElse(newCaseWhen);
|
||||
}
|
||||
|
||||
@Override
public Expression visitInPredicate(InPredicate inPredicate, CascadesContext context) {
List<Expression> rewrittenChildren = inPredicate.children().stream()
.map(e -> e.accept(this, context)).collect(Collectors.toList());
InPredicate newInPredicate = inPredicate.withChildren(rewrittenChildren);

// check
newInPredicate.checkLegalityBeforeTypeCoercion();

// type coercion
if (newInPredicate.getOptions().stream().map(Expression::getDataType)
.allMatch(dt -> dt.equals(newInPredicate.getCompareExpr().getDataType()))) {
return newInPredicate;
}
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(newInPredicate.children()
.stream().map(Expression::getDataType).collect(Collectors.toList()));

return optionalCommonType
.map(commonType -> {
List<Expression> newChildren = newInPredicate.children().stream()
.map(e -> TypeCoercionUtils.castIfNotSameType(e, commonType))
.collect(Collectors.toList());
return newInPredicate.withChildren(newChildren);
})
.orElse(newInPredicate);
}

@Override
public Expression visitBitNot(BitNot bitNot, CascadesContext context) {
Expression child = bitNot.child().accept(this, context);
// check
bitNot.checkLegalityBeforeTypeCoercion();

// type coercion
if (child.getDataType().toCatalogDataType().getPrimitiveType().ordinal() > PrimitiveType.LARGEINT.ordinal()) {
child = new Cast(child, BigIntType.INSTANCE);
}
return bitNot.withChildren(child);
}

private Optional<Expression> characterLiteralTypeCoercion(String value, DataType dataType) {
Expression ret = null;
try {
if (dataType instanceof BooleanType) {
if ("true".equalsIgnoreCase(value)) {
ret = BooleanLiteral.TRUE;
}
if ("false".equalsIgnoreCase(value)) {
ret = BooleanLiteral.FALSE;
}
} else if (dataType instanceof IntegralType) {
BigInteger bigInt = new BigInteger(value);
if (BigInteger.valueOf(bigInt.byteValue()).equals(bigInt)) {
ret = new TinyIntLiteral(bigInt.byteValue());
} else if (BigInteger.valueOf(bigInt.shortValue()).equals(bigInt)) {
ret = new SmallIntLiteral(bigInt.shortValue());
} else if (BigInteger.valueOf(bigInt.intValue()).equals(bigInt)) {
ret = new IntegerLiteral(bigInt.intValue());
} else if (BigInteger.valueOf(bigInt.longValue()).equals(bigInt)) {
ret = new BigIntLiteral(bigInt.longValueExact());
} else {
ret = new LargeIntLiteral(bigInt);
}
} else if (dataType instanceof FloatType) {
ret = new FloatLiteral(Float.parseFloat(value));
} else if (dataType instanceof DoubleType) {
ret = new DoubleLiteral(Double.parseDouble(value));
} else if (dataType instanceof DecimalV2Type) {
ret = new DecimalLiteral(new BigDecimal(value));
} else if (dataType instanceof CharType) {
ret = new CharLiteral(value, value.length());
} else if (dataType instanceof VarcharType) {
ret = new VarcharLiteral(value, value.length());
} else if (dataType instanceof StringType) {
ret = new StringLiteral(value);
} else if (dataType instanceof DateType) {
ret = new DateLiteral(value);
} else if (dataType instanceof DateTimeType) {
ret = new DateTimeLiteral(value);
}
} catch (Exception e) {
// ignore
}
return Optional.ofNullable(ret);
}

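The IntegralType branch above parses the string once and then keeps the narrowest literal that can hold the value. That narrowing is plain JDK logic and can be sketched in isolation (a runnable mimic, not Doris code):

    import java.math.BigInteger;

    public class NarrowestIntegralLiteralSketch {
        // Same idea as characterLiteralTypeCoercion's IntegralType branch:
        // try byte, short, int, long in turn, fall back to arbitrary precision.
        static String narrowestType(String value) {
            BigInteger v = new BigInteger(value);
            if (BigInteger.valueOf(v.byteValue()).equals(v)) {
                return "TINYINT";
            } else if (BigInteger.valueOf(v.shortValue()).equals(v)) {
                return "SMALLINT";
            } else if (BigInteger.valueOf(v.intValue()).equals(v)) {
                return "INT";
            } else if (BigInteger.valueOf(v.longValue()).equals(v)) {
                return "BIGINT";
            }
            return "LARGEINT";
        }

        public static void main(String[] args) {
            System.out.println(narrowestType("100"));                  // TINYINT
            System.out.println(narrowestType("40000"));                // INT
            System.out.println(narrowestType("99999999999999999999")); // LARGEINT
        }
    }
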
private Expression visitImplicitCastInputTypes(Expression expr, List<AbstractDataType> expectedInputTypes) {
List<Optional<DataType>> inputImplicitCastTypes
= getInputImplicitCastTypes(expr.children(), expectedInputTypes);
return castInputs(expr, inputImplicitCastTypes);
}

private List<Optional<DataType>> getInputImplicitCastTypes(
List<Expression> inputs, List<AbstractDataType> expectedTypes) {
Builder<Optional<DataType>> implicitCastTypes = ImmutableList.builder();
for (int i = 0; i < inputs.size(); i++) {
DataType argType = inputs.get(i).getDataType();
AbstractDataType expectedType = expectedTypes.get(i);
Optional<DataType> castType = TypeCoercionUtils.implicitCast(argType, expectedType);
// TODO: complete the cast logic like FunctionCallExpr.analyzeImpl
boolean legacyCastCompatible = expectedType instanceof DataType
&& !(expectedType.getClass().equals(NumericType.class))
&& !(expectedType.getClass().equals(IntegralType.class))
&& !(expectedType.getClass().equals(FractionalType.class))
&& !(expectedType.getClass().equals(CharacterType.class))
&& !argType.toCatalogDataType().matchesType(expectedType.toCatalogDataType());
if (!castType.isPresent() && legacyCastCompatible) {
castType = Optional.of((DataType) expectedType);
}
implicitCastTypes.add(castType);
}
return implicitCastTypes.build();
}

private Expression castInputs(Expression expr, List<Optional<DataType>> castTypes) {
return expr.withChildren((child, childIndex) -> {
DataType argType = child.getDataType();
Optional<DataType> castType = castTypes.get(childIndex);
if (castType.isPresent() && !castType.get().equals(argType)) {
return TypeCoercionUtils.castIfNotSameType(child, castType.get());
} else {
return child;
}
});
}
}
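
Taken together, getInputImplicitCastTypes and castInputs resolve an optional target type per argument and wrap a cast only where a target exists and differs from the argument's current type. A self-contained mimic of that per-argument loop in plain Java (illustrative names, not the Doris API):

    import java.util.Arrays;
    import java.util.List;
    import java.util.Optional;
    import java.util.stream.Collectors;
    import java.util.stream.IntStream;

    public class CastInputsSketch {
        // For each argument, emit "cast(arg as target)" only when a target type
        // was resolved and it differs from the argument's own type.
        static List<String> castPlan(List<String> argTypes, List<Optional<String>> targets) {
            return IntStream.range(0, argTypes.size())
                    .mapToObj(i -> targets.get(i)
                            .filter(t -> !t.equals(argTypes.get(i)))
                            .map(t -> "cast(" + argTypes.get(i) + " as " + t + ")")
                            .orElse(argTypes.get(i)))
                    .collect(Collectors.toList());
        }

        public static void main(String[] args) {
            List<String> argTypes = Arrays.asList("TINYINT", "DOUBLE");
            List<Optional<String>> targets = Arrays.asList(Optional.of("DOUBLE"), Optional.of("DOUBLE"));
            System.out.println(castPlan(argTypes, targets));
            // [cast(TINYINT as DOUBLE), DOUBLE]
        }
    }
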
@ -29,7 +29,6 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.planner.PlannerContext;

import org.apache.commons.lang.StringUtils;

@ -37,9 +36,9 @@ import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

class Binder extends SubExprAnalyzer {
class SlotBinder extends SubExprAnalyzer {

public Binder(Scope scope, CascadesContext cascadesContext) {
public SlotBinder(Scope scope, CascadesContext cascadesContext) {
super(scope, cascadesContext);
}

@ -48,7 +47,7 @@ class Binder extends SubExprAnalyzer {
}

@Override
public Expression visitUnboundAlias(UnboundAlias unboundAlias, PlannerContext context) {
public Expression visitUnboundAlias(UnboundAlias unboundAlias, CascadesContext context) {
Expression child = unboundAlias.child().accept(this, context);
if (unboundAlias.getAlias().isPresent()) {
return new Alias(child, unboundAlias.getAlias().get());
@ -62,7 +61,7 @@ class Binder extends SubExprAnalyzer {
}

@Override
public Slot visitUnboundSlot(UnboundSlot unboundSlot, PlannerContext context) {
public Slot visitUnboundSlot(UnboundSlot unboundSlot, CascadesContext context) {
Optional<List<Slot>> boundedOpt = Optional.of(bindSlot(unboundSlot, getScope().getSlots()));
boolean foundInThisScope = !boundedOpt.get().isEmpty();
// Currently only looking for symbols on the previous level.
@ -110,7 +109,7 @@ class Binder extends SubExprAnalyzer {
}

@Override
public Expression visitUnboundStar(UnboundStar unboundStar, PlannerContext context) {
public Expression visitUnboundStar(UnboundStar unboundStar, CascadesContext context) {
List<String> qualifier = unboundStar.getQualifier();
List<Slot> slots = getScope().getSlots()
.stream()
@ -33,7 +33,6 @@ import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewri
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.planner.PlannerContext;

import com.google.common.collect.ImmutableList;

@ -45,7 +44,7 @@ import java.util.Optional;
/**
* Use the visitor to iterate sub expression.
*/
public class SubExprAnalyzer extends DefaultExpressionRewriter<PlannerContext> {
class SubExprAnalyzer extends DefaultExpressionRewriter<CascadesContext> {

private final Scope scope;
private final CascadesContext cascadesContext;
@ -56,7 +55,7 @@ public class SubExprAnalyzer extends DefaultExpressionRewriter<PlannerContext> {
}

@Override
public Expression visitNot(Not not, PlannerContext context) {
public Expression visitNot(Not not, CascadesContext context) {
Expression child = not.child();
if (child instanceof Exists) {
return visitExistsSubquery(
@ -69,7 +68,7 @@ public class SubExprAnalyzer extends DefaultExpressionRewriter<PlannerContext> {
}

@Override
public Expression visitExistsSubquery(Exists exists, PlannerContext context) {
public Expression visitExistsSubquery(Exists exists, CascadesContext context) {
AnalyzedResult analyzedResult = analyzeSubquery(exists);

return new Exists(analyzedResult.getLogicalPlan(),
@ -77,7 +76,7 @@ public class SubExprAnalyzer extends DefaultExpressionRewriter<PlannerContext> {
}

@Override
public Expression visitInSubquery(InSubquery expr, PlannerContext context) {
public Expression visitInSubquery(InSubquery expr, CascadesContext context) {
AnalyzedResult analyzedResult = analyzeSubquery(expr);

checkOutputColumn(analyzedResult.getLogicalPlan());
@ -90,7 +89,7 @@ public class SubExprAnalyzer extends DefaultExpressionRewriter<PlannerContext> {
}

@Override
public Expression visitScalarSubquery(ScalarSubquery scalar, PlannerContext context) {
public Expression visitScalarSubquery(ScalarSubquery scalar, CascadesContext context) {
AnalyzedResult analyzedResult = analyzeSubquery(scalar);

checkOutputColumn(analyzedResult.getLogicalPlan());

@ -45,6 +45,7 @@ public class ExpressionNormalization extends ExpressionRewrite {
BetweenToCompoundRule.INSTANCE,
InPredicateToEqualToRule.INSTANCE,
SimplifyNotExprRule.INSTANCE,
// TODO(morrySnow): remove type coercion from here after we could process subquery type coercion when bind
CharacterLiteralTypeCoercion.INSTANCE,
SimplifyArithmeticRule.INSTANCE,
TypeCoercion.INSTANCE,

@ -57,6 +57,7 @@ import java.util.Optional;
/**
* coercion character literal to another side
*/
@Deprecated
@DependsRules(CheckLegalityBeforeTypeCoercion.class)
public class CharacterLiteralTypeCoercion extends AbstractExpressionRewriteRule {

@ -20,7 +20,6 @@ package org.apache.doris.nereids.rules.expression.rewrite.rules;
import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.nereids.annotation.DependsRules;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.jobs.batch.CheckLegalityBeforeTypeCoercion;
import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
@ -56,7 +55,7 @@ import java.util.stream.Collectors;
* a rule to add implicit cast for expressions.
* This class is inspired by spark's TypeCoercion.
*/
@Developing
@Deprecated
@DependsRules(CheckLegalityBeforeTypeCoercion.class)
public class TypeCoercion extends AbstractExpressionRewriteRule {

@ -21,8 +21,6 @@ import org.apache.doris.catalog.SchemaTable;
import org.apache.doris.catalog.TableIf;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.RelationId;
@ -30,8 +28,6 @@ import org.apache.doris.nereids.trees.plans.algebra.Scan;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.Utils;

import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Optional;

@ -59,14 +55,6 @@ public class LogicalSchemaScan extends LogicalRelation implements Scan {
return visitor.visitLogicalSchemaScan(this, context);
}

@Override
public List<Slot> computeNonUserVisibleOutput() {
SchemaTable schemaTable = getTable();
return schemaTable.getBaseSchema().stream()
.map(col -> SlotReference.fromColumn(col, qualified()))
.collect(ImmutableList.toImmutableList());
}

@Override
public Plan withGroupExpression(Optional<GroupExpression> groupExpression) {
return new LogicalSchemaScan(id, table, qualifier, groupExpression, Optional.of(getLogicalProperties()));

@ -314,7 +314,7 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Patt
ExceptionChecker.expectThrowsWithMsg(
AnalysisException.class,
"Aggregate functions in having clause can't be nested:"
+ " sum(cast((cast(a1 as DOUBLE) + avg(a2)) as SMALLINT)).",
+ " sum((cast(a1 as DOUBLE) + avg(a2))).",
() -> PlanChecker.from(connectContext).analyze(
"SELECT a1 FROM t1 GROUP BY a1 HAVING SUM(a1 + AVG(a2)) > 0"
));
@ -322,7 +322,7 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Patt
ExceptionChecker.expectThrowsWithMsg(
AnalysisException.class,
"Aggregate functions in having clause can't be nested:"
+ " sum(cast((cast((a1 + a2) as DOUBLE) + avg(a2)) as INT)).",
+ " sum((cast((a1 + a2) as DOUBLE) + avg(a2))).",
() -> PlanChecker.from(connectContext).analyze(
"SELECT a1 FROM t1 GROUP BY a1 HAVING SUM(a1 + a2 + AVG(a2)) > 0"
));

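The relaxed expected messages above reflect the new coercion: a1 (TINYINT) is widened to DOUBLE to match avg(a2), so no narrowing cast back to SMALLINT or INT is wrapped around sum's argument any more. Java's own numeric promotion illustrates the widening this relies on (a plain-JDK sketch, not Doris code):

    public class WideningSketch {
        public static void main(String[] args) {
            byte a1 = 5;        // stands in for a TINYINT column value
            double avgA2 = 2.5; // avg() produces a DOUBLE
            // byte + double widens to double, matching the DOUBLE argument type
            // the analyzer is now expected to derive for sum(a1 + avg(a2)).
            double arg = a1 + avgA2;
            System.out.println(arg); // 7.5
        }
    }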