[refactor] bind slot and function in one rule (#16288)

1. Use one rule to bind slots and functions and perform type coercion during binding, to fix type and nullability errors.
  a. For example, SUM(a1 + AVG(a2)) where a1 and a2 are TINYINT: before this PR the return type was SMALLINT; after this PR it is the correct type, DOUBLE.
2. Fix a runtime filter generator bug: runtime filters were bound to the wrong join conjuncts.
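For illustration, a minimal SQL sketch of the type issue in item 1.a, taken from the FillUpMissingSlotsTest change at the end of this diff (a table t1 with TINYINT columns a1 and a2 is assumed):

    SELECT a1 FROM t1 GROUP BY a1 HAVING SUM(a1 + AVG(a2)) > 0;

The statement is still rejected because aggregate functions in HAVING cannot be nested, but the expression in the error message is now typed as sum((cast(a1 as DOUBLE) + avg(a2))) rather than sum(cast((cast(a1 as DOUBLE) + avg(a2)) as SMALLINT)), i.e. the SUM argument keeps its DOUBLE type instead of being cast back to SMALLINT.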
morrySnow
2023-02-02 15:02:32 +08:00
committed by GitHub
parent 42960ffd08
commit a6c1eaf1d8
13 changed files with 594 additions and 408 deletions

View File

@ -22,7 +22,6 @@ import org.apache.doris.nereids.jobs.batch.AdjustAggregateNullableForEmptySetJob
import org.apache.doris.nereids.jobs.batch.AnalyzeRulesJob;
import org.apache.doris.nereids.jobs.batch.AnalyzeSubqueryRulesJob;
import org.apache.doris.nereids.jobs.batch.CheckAnalysisJob;
import org.apache.doris.nereids.jobs.batch.TypeCoercionJob;
import java.util.Objects;
import java.util.Optional;
@ -52,7 +51,6 @@ public class NereidsAnalyzer {
new AnalyzeRulesJob(cascadesContext, outerScope).execute();
new AnalyzeSubqueryRulesJob(cascadesContext).execute();
new AdjustAggregateNullableForEmptySetJob(cascadesContext).execute();
new TypeCoercionJob(cascadesContext).execute();
// check whether analyze result is meaningful
new CheckAnalysisJob(cascadesContext).execute();
}

View File

@ -20,9 +20,8 @@ package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount;
import org.apache.doris.nereids.rules.analysis.BindFunction;
import org.apache.doris.nereids.rules.analysis.BindExpression;
import org.apache.doris.nereids.rules.analysis.BindRelation;
import org.apache.doris.nereids.rules.analysis.BindSlotReference;
import org.apache.doris.nereids.rules.analysis.CheckPolicy;
import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots;
import org.apache.doris.nereids.rules.analysis.NormalizeRepeat;
@ -32,9 +31,6 @@ import org.apache.doris.nereids.rules.analysis.RegisterCTE;
import org.apache.doris.nereids.rules.analysis.ReplaceExpressionByChildOutput;
import org.apache.doris.nereids.rules.analysis.ResolveOrdinalInOrderByAndGroupBy;
import org.apache.doris.nereids.rules.analysis.UserAuthentication;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization;
import org.apache.doris.nereids.rules.expression.rewrite.rules.CharacterLiteralTypeCoercion;
import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
import org.apache.doris.nereids.rules.rewrite.logical.HideOneRowRelationUnderUnion;
import com.google.common.collect.ImmutableList;
@ -61,8 +57,7 @@ public class AnalyzeRulesJob extends BatchRulesJob {
new BindRelation(),
new CheckPolicy(),
new UserAuthentication(),
new BindSlotReference(scope),
new BindFunction(),
new BindExpression(scope),
new ProjectToGlobalAggregate(),
// this rule checks the logicalProject node's isDistinct property
// and replaces the logicalProject node with a LogicalAggregate node
@ -73,9 +68,7 @@ public class AnalyzeRulesJob extends BatchRulesJob {
new AvgDistinctToSumDivCount(),
new ResolveOrdinalInOrderByAndGroupBy(),
new ReplaceExpressionByChildOutput(),
new HideOneRowRelationUnderUnion(),
new ExpressionNormalization(cascadesContext.getConnectContext(),
ImmutableList.of(CharacterLiteralTypeCoercion.INSTANCE, TypeCoercion.INSTANCE))
new HideOneRowRelationUnderUnion()
),
topDownBatch(
new FillUpMissingSlots(),

View File

@ -45,7 +45,6 @@ import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
/**
@ -89,31 +88,27 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
List<TRuntimeFilterType> legalTypes = Arrays.stream(TRuntimeFilterType.values())
.filter(type -> (type.getValue() & ctx.getSessionVariable().getRuntimeFilterType()) > 0)
.collect(Collectors.toList());
AtomicInteger cnt = new AtomicInteger();
join.getHashJoinConjuncts().stream()
.map(EqualTo.class::cast)
// TODO: some complex situation cannot be handled now, see testPushDownThroughJoin.
// TODO: we will support it in later version.
.forEach(expr -> legalTypes.forEach(type -> {
Pair<Expression, Expression> normalizedChildren = checkAndMaybeSwapChild(expr, join);
// if aliasTransferMap doesn't contain the key, it means that the path from the olap scan to the join
// contains a join with a denied join type, for example: a left join b on a.id = b.id
if (normalizedChildren == null
|| !aliasTransferMap.containsKey((Slot) normalizedChildren.first)) {
return;
}
Pair<Slot, Slot> slots = Pair.of(
aliasTransferMap.get((Slot) normalizedChildren.first).second.toSlot(),
((Slot) normalizedChildren.second));
RuntimeFilter filter = new RuntimeFilter(generator.getNextId(),
slots.second, slots.first, type,
cnt.getAndIncrement(), join);
ctx.addJoinToTargetMap(join, slots.first.getExprId());
ctx.setTargetExprIdToFilter(slots.first.getExprId(), filter);
ctx.setTargetsOnScanNode(
aliasTransferMap.get((Slot) normalizedChildren.first).first,
slots.first);
}));
// TODO: some complex situation cannot be handled now, see testPushDownThroughJoin.
// TODO: we will support it in later version.
for (int i = 0; i < join.getHashJoinConjuncts().size(); i++) {
EqualTo expr = (EqualTo) join.getHashJoinConjuncts().get(i);
for (TRuntimeFilterType type : legalTypes) {
Pair<Expression, Expression> normalizedChildren = checkAndMaybeSwapChild(expr, join);
// if aliasTransferMap doesn't contain the key, it means that the path from the olap scan to the join
// contains a join with a denied join type, for example: a left join b on a.id = b.id
if (normalizedChildren == null || !aliasTransferMap.containsKey((Slot) normalizedChildren.first)) {
continue;
}
Pair<Slot, Slot> slots = Pair.of(
aliasTransferMap.get((Slot) normalizedChildren.first).second.toSlot(),
((Slot) normalizedChildren.second));
RuntimeFilter filter = new RuntimeFilter(generator.getNextId(),
slots.second, slots.first, type, i, join);
ctx.addJoinToTargetMap(join, slots.first.getExprId());
ctx.setTargetExprIdToFilter(slots.first.getExprId(), filter);
ctx.setTargetsOnScanNode(aliasTransferMap.get((Slot) normalizedChildren.first).first, slots.first);
}
}
}
return join;
}

View File

@ -17,17 +17,20 @@
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.catalog.BuiltinAggregateFunctions;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.analyzer.UnboundOneRowRelation;
import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.analyzer.UnboundTVFRelation;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.analysis.BindFunction.FunctionBinder;
import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
import org.apache.doris.nereids.trees.UnaryNode;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.BoundStar;
@ -36,7 +39,13 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.TVFProperties;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
import org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
@ -57,6 +66,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalTVFRelation;
import org.apache.doris.nereids.trees.plans.logical.UsingJoin;
import com.google.common.base.Preconditions;
@ -81,11 +91,11 @@ import java.util.stream.Stream;
* BindSlotReference.
*/
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class BindSlotReference implements AnalysisRuleFactory {
public class BindExpression implements AnalysisRuleFactory {
private final Optional<Scope> outerScope;
public BindSlotReference(Optional<Scope> outerScope) {
public BindExpression(Optional<Scope> outerScope) {
this.outerScope = Objects.requireNonNull(outerScope, "outerScope cannot be null");
}
@ -104,18 +114,23 @@ public class BindSlotReference implements AnalysisRuleFactory {
logicalProject().when(Plan::canBind).thenApply(ctx -> {
LogicalProject<GroupPlan> project = ctx.root;
List<NamedExpression> boundProjections =
bind(project.getProjects(), project.children(), ctx.cascadesContext);
List<NamedExpression> boundExceptions = bind(project.getExcepts(), project.children(),
bindSlot(project.getProjects(), project.children(), ctx.cascadesContext);
List<NamedExpression> boundExceptions = bindSlot(project.getExcepts(), project.children(),
ctx.cascadesContext);
boundProjections = flatBoundStar(boundProjections, boundExceptions);
boundProjections = boundProjections.stream()
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(ImmutableList.toImmutableList());
return new LogicalProject<>(boundProjections, project.child(), project.isDistinct());
})
),
RuleType.BINDING_FILTER_SLOT.build(
logicalFilter().when(Plan::canBind).thenApply(ctx -> {
LogicalFilter<GroupPlan> filter = ctx.root;
Set<Expression> boundConjuncts
= bind(filter.getConjuncts(), filter.children(), ctx.cascadesContext);
Set<Expression> boundConjuncts = filter.getConjuncts().stream()
.map(expr -> bindSlot(expr, filter.children(), ctx.cascadesContext))
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(Collectors.toSet());
return new LogicalFilter<>(boundConjuncts, filter.child());
})
),
@ -141,7 +156,7 @@ public class BindSlotReference implements AnalysisRuleFactory {
.peek(s -> slotNames.add(s.getName()))
.collect(Collectors.toList()));
for (Expression unboundSlot : unboundSlots) {
Expression expression = new Binder(scope, ctx.cascadesContext).bind(unboundSlot);
Expression expression = new SlotBinder(scope, ctx.cascadesContext).bind(unboundSlot);
leftSlots.add(expression);
}
slotNames.clear();
@ -151,7 +166,7 @@ public class BindSlotReference implements AnalysisRuleFactory {
.collect(Collectors.toList()));
List<Expression> rightSlots = new ArrayList<>();
for (Expression unboundSlot : unboundSlots) {
Expression expression = new Binder(scope, ctx.cascadesContext).bind(unboundSlot);
Expression expression = new SlotBinder(scope, ctx.cascadesContext).bind(unboundSlot);
rightSlots.add(expression);
}
int size = leftSlots.size();
@ -166,10 +181,12 @@ public class BindSlotReference implements AnalysisRuleFactory {
logicalJoin().when(Plan::canBind).thenApply(ctx -> {
LogicalJoin<GroupPlan, GroupPlan> join = ctx.root;
List<Expression> cond = join.getOtherJoinConjuncts().stream()
.map(expr -> bind(expr, join.children(), ctx.cascadesContext))
.map(expr -> bindSlot(expr, join.children(), ctx.cascadesContext))
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(Collectors.toList());
List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts().stream()
.map(expr -> bind(expr, join.children(), ctx.cascadesContext))
.map(expr -> bindSlot(expr, join.children(), ctx.cascadesContext))
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(Collectors.toList());
return new LogicalJoin<>(join.getJoinType(),
hashJoinConjuncts, cond, join.getHint(), join.left(), join.right());
@ -178,8 +195,10 @@ public class BindSlotReference implements AnalysisRuleFactory {
RuleType.BINDING_AGGREGATE_SLOT.build(
logicalAggregate().when(Plan::canBind).thenApply(ctx -> {
LogicalAggregate<GroupPlan> agg = ctx.root;
List<NamedExpression> output =
bind(agg.getOutputExpressions(), agg.children(), ctx.cascadesContext);
List<NamedExpression> output = agg.getOutputExpressions().stream()
.map(expr -> bindSlot(expr, agg.children(), ctx.cascadesContext))
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(ImmutableList.toImmutableList());
// The columns referenced in group by are first obtained from the child's output,
// and then from the node's output
@ -235,17 +254,8 @@ public class BindSlotReference implements AnalysisRuleFactory {
.filter(ne -> ne instanceof Alias)
.map(Alias.class::cast)
// agg function cannot be bound with group_by_key
// TODO(morrySnow): after bind function moved here,
// we could just use instanceof AggregateFunction
.filter(alias -> !alias.child().anyMatch(expr -> {
if (expr instanceof UnboundFunction) {
UnboundFunction unboundFunction = (UnboundFunction) expr;
return BuiltinAggregateFunctions.INSTANCE.aggFuncNames.contains(
unboundFunction.getName().toLowerCase());
}
return false;
}
))
.filter(alias -> !alias.child()
.anyMatch(expr -> expr instanceof AggregateFunction))
.forEach(alias -> childOutputsToExpr.putIfAbsent(alias.getName(), alias.child()));
List<Expression> replacedGroupBy = agg.getGroupByExpressions().stream()
@ -280,9 +290,10 @@ public class BindSlotReference implements AnalysisRuleFactory {
.collect(Collectors.toSet());
boundSlots.addAll(outputSlots);
Binder binder = new Binder(toScope(Lists.newArrayList(boundSlots)), ctx.cascadesContext);
Binder childBinder
= new Binder(toScope(new ArrayList<>(agg.child().getOutputSet())), ctx.cascadesContext);
SlotBinder binder = new SlotBinder(
toScope(Lists.newArrayList(boundSlots)), ctx.cascadesContext);
SlotBinder childBinder = new SlotBinder(
toScope(new ArrayList<>(agg.child().getOutputSet())), ctx.cascadesContext);
List<Expression> groupBy = replacedGroupBy.stream()
.map(expression -> {
@ -305,15 +316,19 @@ public class BindSlotReference implements AnalysisRuleFactory {
if (hasUnBound.test(groupBy)) {
throw new AnalysisException("cannot bind GROUP BY KEY: " + unboundGroupBys.get(0).toSql());
}
groupBy = groupBy.stream()
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(ImmutableList.toImmutableList());
return agg.withGroupByAndOutput(groupBy, output);
})
),
RuleType.BINDING_REPEAT_SLOT.build(
logicalRepeat().when(Plan::canBind).thenApply(ctx -> {
LogicalRepeat<GroupPlan> repeat = ctx.root;
List<NamedExpression> output =
bind(repeat.getOutputExpressions(), repeat.children(), ctx.cascadesContext);
List<NamedExpression> output = repeat.getOutputExpressions().stream()
.map(expr -> bindSlot(expr, repeat.children(), ctx.cascadesContext))
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(ImmutableList.toImmutableList());
// The columns referenced in group by are first obtained from the child's output,
// and then from the node's output
@ -343,7 +358,10 @@ public class BindSlotReference implements AnalysisRuleFactory {
List<List<Expression>> groupingSets = replacedGroupingSets
.stream()
.map(groupingSet -> bind(groupingSet, repeat.children(), ctx.cascadesContext))
.map(groupingSet -> groupingSet.stream()
.map(expr -> bindSlot(expr, repeat.children(), ctx.cascadesContext))
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(ImmutableList.toImmutableList()))
.collect(ImmutableList.toImmutableList());
List<NamedExpression> newOutput = adjustNullableForRepeat(groupingSets, output);
return repeat.withGroupSetsAndOutput(groupingSets, newOutput);
@ -383,7 +401,8 @@ public class BindSlotReference implements AnalysisRuleFactory {
List<OrderKey> sortItemList = sort.getOrderKeys()
.stream()
.map(orderKey -> {
Expression item = bind(orderKey.getExpr(), sort.child(), ctx.cascadesContext);
Expression item = bindSlot(orderKey.getExpr(), sort.child(), ctx.cascadesContext);
item = bindFunction(item, ctx.cascadesContext);
return new OrderKey(item, orderKey.isAsc(), orderKey.isNullFirst());
}).collect(Collectors.toList());
return new LogicalSort<>(sortItemList, sort.child());
@ -393,11 +412,12 @@ public class BindSlotReference implements AnalysisRuleFactory {
logicalHaving(aggregate()).when(Plan::canBind).thenApply(ctx -> {
LogicalHaving<Aggregate<GroupPlan>> having = ctx.root;
Plan childPlan = having.child();
Set<Expression> boundConjuncts = having.getConjuncts().stream().map(
expr -> {
expr = bind(expr, childPlan.children(), ctx.cascadesContext);
return bind(expr, childPlan, ctx.cascadesContext);
Set<Expression> boundConjuncts = having.getConjuncts().stream()
.map(expr -> {
expr = bindSlot(expr, childPlan.children(), ctx.cascadesContext);
return bindSlot(expr, childPlan, ctx.cascadesContext);
})
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(Collectors.toSet());
return new LogicalHaving<>(boundConjuncts, having.child());
})
@ -406,11 +426,12 @@ public class BindSlotReference implements AnalysisRuleFactory {
logicalHaving(any()).when(Plan::canBind).thenApply(ctx -> {
LogicalHaving<Plan> having = ctx.root;
Plan childPlan = having.child();
Set<Expression> boundConjuncts = having.getConjuncts().stream().map(
expr -> {
expr = bind(expr, childPlan, ctx.cascadesContext);
return bind(expr, childPlan.children(), ctx.cascadesContext);
})
Set<Expression> boundConjuncts = having.getConjuncts().stream()
.map(expr -> {
expr = bindSlot(expr, childPlan, ctx.cascadesContext);
return bindSlot(expr, childPlan.children(), ctx.cascadesContext);
})
.map(expr -> bindFunction(expr, ctx.cascadesContext))
.collect(Collectors.toSet());
return new LogicalHaving<>(boundConjuncts, having.child());
})
@ -421,7 +442,8 @@ public class BindSlotReference implements AnalysisRuleFactory {
UnboundOneRowRelation oneRowRelation = ctx.root;
List<NamedExpression> projects = oneRowRelation.getProjects()
.stream()
.map(project -> bind(project, ImmutableList.of(), ctx.cascadesContext))
.map(project -> bindSlot(project, ImmutableList.of(), ctx.cascadesContext))
.map(project -> bindFunction(project, ctx.cascadesContext))
.collect(Collectors.toList());
return new LogicalOneRowRelation(projects);
})
@ -450,26 +472,31 @@ public class BindSlotReference implements AnalysisRuleFactory {
})
),
RuleType.BINDING_GENERATE_SLOT.build(
logicalGenerate().when(Plan::canBind).thenApply(ctx -> {
LogicalGenerate<GroupPlan> generate = ctx.root;
List<Function> boundSlotGenerators
= bind(generate.getGenerators(), generate.children(), ctx.cascadesContext);
List<Function> boundFunctionGenerators = boundSlotGenerators.stream()
.map(f -> FunctionBinder.INSTANCE.bindTableGeneratingFunction(
(UnboundFunction) f, ctx.statementContext))
.collect(Collectors.toList());
ImmutableList.Builder<Slot> slotBuilder = ImmutableList.builder();
for (int i = 0; i < generate.getGeneratorOutput().size(); i++) {
Function generator = boundFunctionGenerators.get(i);
UnboundSlot slot = (UnboundSlot) generate.getGeneratorOutput().get(i);
Preconditions.checkState(slot.getNameParts().size() == 2,
"the size of nameParts of UnboundSlot in LogicalGenerate must be 2.");
Slot boundSlot = new SlotReference(slot.getNameParts().get(1), generator.getDataType(),
generator.nullable(), ImmutableList.of(slot.getNameParts().get(0)));
slotBuilder.add(boundSlot);
}
return new LogicalGenerate<>(boundFunctionGenerators, slotBuilder.build(), generate.child());
})
logicalGenerate().when(Plan::canBind).thenApply(ctx -> {
LogicalGenerate<GroupPlan> generate = ctx.root;
List<Function> boundSlotGenerators
= bindSlot(generate.getGenerators(), generate.children(), ctx.cascadesContext);
List<Function> boundFunctionGenerators = boundSlotGenerators.stream()
.map(f -> bindTableGeneratingFunction((UnboundFunction) f, ctx.cascadesContext))
.collect(Collectors.toList());
ImmutableList.Builder<Slot> slotBuilder = ImmutableList.builder();
for (int i = 0; i < generate.getGeneratorOutput().size(); i++) {
Function generator = boundFunctionGenerators.get(i);
UnboundSlot slot = (UnboundSlot) generate.getGeneratorOutput().get(i);
Preconditions.checkState(slot.getNameParts().size() == 2,
"the size of nameParts of UnboundSlot in LogicalGenerate must be 2.");
Slot boundSlot = new SlotReference(slot.getNameParts().get(1), generator.getDataType(),
generator.nullable(), ImmutableList.of(slot.getNameParts().get(0)));
slotBuilder.add(boundSlot);
}
return new LogicalGenerate<>(boundFunctionGenerators, slotBuilder.build(), generate.child());
})
),
RuleType.BINDING_UNBOUND_TVF_RELATION_FUNCTION.build(
unboundTVFRelation().thenApply(ctx -> {
UnboundTVFRelation relation = ctx.root;
return bindTableValuedFunction(relation, ctx.statementContext);
})
),
// when child update, we need update current plan's logical properties,
@ -502,8 +529,9 @@ public class BindSlotReference implements AnalysisRuleFactory {
List<OrderKey> sortItemList = sort.getOrderKeys()
.stream()
.map(orderKey -> {
Expression item = bind(orderKey.getExpr(), plan, ctx);
item = bind(item, plan.children(), ctx);
Expression item = bindSlot(orderKey.getExpr(), plan, ctx);
item = bindSlot(item, plan.children(), ctx);
item = bindFunction(item, ctx);
return new OrderKey(item, orderKey.isAsc(), orderKey.isNullFirst());
}).collect(Collectors.toList());
return new LogicalSort<>(sortItemList, sort.child());
@ -525,29 +553,36 @@ public class BindSlotReference implements AnalysisRuleFactory {
.collect(ImmutableList.toImmutableList());
}
private <E extends Expression> List<E> bind(List<E> exprList, List<Plan> inputs, CascadesContext cascadesContext) {
private <E extends Expression> List<E> bindSlot(
List<E> exprList, List<Plan> inputs, CascadesContext cascadesContext) {
return exprList.stream()
.map(expr -> bind(expr, inputs, cascadesContext))
.map(expr -> bindSlot(expr, inputs, cascadesContext))
.collect(Collectors.toList());
}
private <E extends Expression> Set<E> bind(Set<E> exprList, List<Plan> inputs, CascadesContext cascadesContext) {
private <E extends Expression> Set<E> bindSlot(
Set<E> exprList, List<Plan> inputs, CascadesContext cascadesContext) {
return exprList.stream()
.map(expr -> bind(expr, inputs, cascadesContext))
.map(expr -> bindSlot(expr, inputs, cascadesContext))
.collect(Collectors.toSet());
}
@SuppressWarnings("unchecked")
private <E extends Expression> E bind(E expr, Plan input, CascadesContext cascadesContext) {
return (E) new Binder(toScope(input.getOutput()), cascadesContext).bind(expr);
private <E extends Expression> E bindSlot(E expr, Plan input, CascadesContext cascadesContext) {
return (E) new SlotBinder(toScope(input.getOutput()), cascadesContext).bind(expr);
}
@SuppressWarnings("unchecked")
private <E extends Expression> E bind(E expr, List<Plan> inputs, CascadesContext cascadesContext) {
private <E extends Expression> E bindSlot(E expr, List<Plan> inputs, CascadesContext cascadesContext) {
List<Slot> boundedSlots = inputs.stream()
.flatMap(input -> input.getOutput().stream())
.collect(Collectors.toList());
return (E) new Binder(toScope(boundedSlots), cascadesContext).bind(expr);
return (E) new SlotBinder(toScope(boundedSlots), cascadesContext).bind(expr);
}
@SuppressWarnings("unchecked")
private <E extends Expression> E bindFunction(E expr, CascadesContext cascadesContext) {
return (E) FunctionBinder.INSTANCE.bind(expr, cascadesContext);
}
/**
@ -578,4 +613,36 @@ public class BindSlotReference implements AnalysisRuleFactory {
return slotReference;
}
}
private LogicalTVFRelation bindTableValuedFunction(UnboundTVFRelation unboundTVFRelation,
StatementContext statementContext) {
Env env = statementContext.getConnectContext().getEnv();
FunctionRegistry functionRegistry = env.getFunctionRegistry();
String functionName = unboundTVFRelation.getFunctionName();
TVFProperties arguments = unboundTVFRelation.getProperties();
FunctionBuilder functionBuilder = functionRegistry.findFunctionBuilder(functionName, arguments);
BoundFunction function = functionBuilder.build(functionName, arguments);
if (!(function instanceof TableValuedFunction)) {
throw new AnalysisException(function.toSql() + " is not a TableValuedFunction");
}
return new LogicalTVFRelation(unboundTVFRelation.getId(), (TableValuedFunction) function);
}
private BoundFunction bindTableGeneratingFunction(UnboundFunction unboundFunction,
CascadesContext cascadesContext) {
List<Expression> boundArguments = unboundFunction.getArguments().stream()
.map(e -> bindFunction(e, cascadesContext))
.collect(Collectors.toList());
FunctionRegistry functionRegistry = cascadesContext.getConnectContext().getEnv().getFunctionRegistry();
String functionName = unboundFunction.getName();
FunctionBuilder functionBuilder = functionRegistry.findFunctionBuilder(functionName, boundArguments);
BoundFunction function = functionBuilder.build(functionName, boundArguments);
if (!(function instanceof TableGeneratingFunction)) {
throw new AnalysisException(function.toSql() + " is not a TableGeneratingFunction");
}
function = (BoundFunction) TypeCoercion.INSTANCE.rewrite(function, null);
return function;
}
}

View File

@ -1,269 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.analyzer.UnboundTVFRelation;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.jobs.batch.CheckLegalityBeforeTypeCoercion;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rewrite.rules.CharacterLiteralTypeCoercion;
import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.TVFProperties;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
import org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalTVFRelation;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
/**
* BindFunction.
*/
public class BindFunction implements AnalysisRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
RuleType.BINDING_ONE_ROW_RELATION_FUNCTION.build(
logicalOneRowRelation().thenApply(ctx -> {
LogicalOneRowRelation oneRowRelation = ctx.root;
List<NamedExpression> projects = oneRowRelation.getProjects();
List<NamedExpression> boundProjects = bindAndTypeCoercion(projects, ctx.connectContext);
if (projects.equals(boundProjects)) {
return oneRowRelation;
}
return new LogicalOneRowRelation(boundProjects);
})
),
RuleType.BINDING_PROJECT_FUNCTION.build(
logicalProject().thenApply(ctx -> {
LogicalProject<GroupPlan> project = ctx.root;
List<NamedExpression> boundExpr = bindAndTypeCoercion(project.getProjects(),
ctx.connectContext);
return new LogicalProject<>(boundExpr, project.child(), project.isDistinct());
})
),
RuleType.BINDING_AGGREGATE_FUNCTION.build(
logicalAggregate().thenApply(ctx -> {
LogicalAggregate<GroupPlan> agg = ctx.root;
List<Expression> groupBy = bindAndTypeCoercion(agg.getGroupByExpressions(),
ctx.connectContext);
List<NamedExpression> output = bindAndTypeCoercion(agg.getOutputExpressions(),
ctx.connectContext);
return agg.withGroupByAndOutput(groupBy, output);
})
),
RuleType.BINDING_REPEAT_FUNCTION.build(
logicalRepeat().thenApply(ctx -> {
LogicalRepeat<GroupPlan> repeat = ctx.root;
List<List<Expression>> groupingSets = repeat.getGroupingSets()
.stream()
.map(groupingSet -> bindAndTypeCoercion(groupingSet, ctx.connectContext))
.collect(ImmutableList.toImmutableList());
List<NamedExpression> output = bindAndTypeCoercion(repeat.getOutputExpressions(),
ctx.connectContext);
return repeat.withGroupSetsAndOutput(groupingSets, output);
})
),
RuleType.BINDING_FILTER_FUNCTION.build(
logicalFilter().thenApply(ctx -> {
LogicalFilter<GroupPlan> filter = ctx.root;
Set<Expression> conjuncts = bindAndTypeCoercion(filter.getConjuncts(), ctx.connectContext);
return new LogicalFilter<>(conjuncts, filter.child());
})
),
RuleType.BINDING_HAVING_FUNCTION.build(
logicalHaving().thenApply(ctx -> {
LogicalHaving<GroupPlan> having = ctx.root;
Set<Expression> conjuncts = bindAndTypeCoercion(having.getConjuncts(), ctx.connectContext);
return new LogicalHaving<>(conjuncts, having.child());
})
),
RuleType.BINDING_SORT_FUNCTION.build(
logicalSort().thenApply(ctx -> {
LogicalSort<GroupPlan> sort = ctx.root;
List<OrderKey> orderKeys = sort.getOrderKeys().stream()
.map(orderKey -> new OrderKey(
bindAndTypeCoercion(orderKey.getExpr(),
ctx.connectContext.getEnv(),
new ExpressionRewriteContext(ctx.connectContext)
),
orderKey.isAsc(),
orderKey.isNullFirst())
)
.collect(ImmutableList.toImmutableList());
return new LogicalSort<>(orderKeys, sort.child());
})
),
RuleType.BINDING_JOIN_FUNCTION.build(
logicalJoin().thenApply(ctx -> {
LogicalJoin<GroupPlan, GroupPlan> join = ctx.root;
List<Expression> hashConjuncts = bindAndTypeCoercion(join.getHashJoinConjuncts(),
ctx.connectContext);
List<Expression> otherConjuncts = bindAndTypeCoercion(join.getOtherJoinConjuncts(),
ctx.connectContext);
return new LogicalJoin<>(join.getJoinType(), hashConjuncts, otherConjuncts,
join.getHint(),
join.left(), join.right());
})
),
RuleType.BINDING_UNBOUND_TVF_RELATION_FUNCTION.build(
unboundTVFRelation().thenApply(ctx -> {
UnboundTVFRelation relation = ctx.root;
return FunctionBinder.INSTANCE.bindTableValuedFunction(relation, ctx.statementContext);
})
)
);
}
private <E extends Expression> List<E> bindAndTypeCoercion(List<? extends E> exprList, ConnectContext ctx) {
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx);
return exprList.stream()
.map(expr -> bindAndTypeCoercion(expr, ctx.getEnv(), rewriteContext))
.collect(ImmutableList.toImmutableList());
}
private <E extends Expression> E bindAndTypeCoercion(E expr, Env env, ExpressionRewriteContext ctx) {
expr = FunctionBinder.INSTANCE.bind(expr, env);
expr = (E) expr.accept(CheckLegalityBeforeTypeCoercion.INSTANCE, ctx);
expr = (E) CharacterLiteralTypeCoercion.INSTANCE.rewrite(expr, ctx);
return (E) TypeCoercion.INSTANCE.rewrite(expr, null);
}
private <E extends Expression> Set<E> bindAndTypeCoercion(Set<? extends E> exprSet, ConnectContext ctx) {
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx);
return exprSet.stream()
.map(expr -> bindAndTypeCoercion(expr, ctx.getEnv(), rewriteContext))
.collect(ImmutableSet.toImmutableSet());
}
/**
* function binder
*/
public static class FunctionBinder extends DefaultExpressionRewriter<Env> {
public static final FunctionBinder INSTANCE = new FunctionBinder();
public <E extends Expression> E bind(E expression, Env env) {
return (E) expression.accept(this, env);
}
/**
* bindTableValuedFunction
*/
public LogicalTVFRelation bindTableValuedFunction(UnboundTVFRelation unboundTVFRelation,
StatementContext statementContext) {
Env env = statementContext.getConnectContext().getEnv();
FunctionRegistry functionRegistry = env.getFunctionRegistry();
String functionName = unboundTVFRelation.getFunctionName();
TVFProperties arguments = unboundTVFRelation.getProperties();
FunctionBuilder functionBuilder = functionRegistry.findFunctionBuilder(functionName, arguments);
BoundFunction function = functionBuilder.build(functionName, arguments);
if (!(function instanceof TableValuedFunction)) {
throw new AnalysisException(function.toSql() + " is not a TableValuedFunction");
}
return new LogicalTVFRelation(unboundTVFRelation.getId(), (TableValuedFunction) function);
}
/**
* bindTableGeneratingFunction
*/
public BoundFunction bindTableGeneratingFunction(UnboundFunction unboundFunction,
StatementContext statementContext) {
Env env = statementContext.getConnectContext().getEnv();
List<Expression> boundArguments = unboundFunction.getArguments().stream()
.map(e -> INSTANCE.bind(e, env))
.collect(ImmutableList.toImmutableList());
FunctionRegistry functionRegistry = env.getFunctionRegistry();
String functionName = unboundFunction.getName();
FunctionBuilder functionBuilder = functionRegistry.findFunctionBuilder(functionName, boundArguments);
BoundFunction function = functionBuilder.build(functionName, boundArguments);
if (!(function instanceof TableGeneratingFunction)) {
throw new AnalysisException(function.toSql() + " is not a TableGeneratingFunction");
}
return function;
}
@Override
public Expression visitUnboundFunction(UnboundFunction unboundFunction, Env env) {
unboundFunction = (UnboundFunction) super.visitUnboundFunction(unboundFunction, env);
FunctionRegistry functionRegistry = env.getFunctionRegistry();
String functionName = unboundFunction.getName();
List<Object> arguments = unboundFunction.isDistinct()
? ImmutableList.builder()
.add(unboundFunction.isDistinct())
.addAll(unboundFunction.getArguments())
.build()
: (List) unboundFunction.getArguments();
FunctionBuilder builder = functionRegistry.findFunctionBuilder(functionName, arguments);
BoundFunction boundFunction = builder.build(functionName, arguments);
return boundFunction;
}
/**
* gets the method for calculating the time.
* e.g. YEARS_ADD, YEARS_SUB, DAYS_ADD, DAYS_SUB
*/
@Override
public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, Env context) {
arithmetic = (TimestampArithmetic) super.visitTimestampArithmetic(arithmetic, context);
String funcOpName;
if (arithmetic.getFuncName() == null) {
// e.g. YEARS_ADD, MONTHS_SUB
funcOpName = String.format("%sS_%s", arithmetic.getTimeUnit(),
(arithmetic.getOp() == Operator.ADD) ? "ADD" : "SUB");
} else {
funcOpName = arithmetic.getFuncName();
}
return arithmetic.withFuncName(funcOpName.toLowerCase(Locale.ROOT));
}
}
}

View File

@ -0,0 +1,415 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
import org.apache.doris.nereids.trees.expressions.BitNot;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IntegralDivide;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.CharLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.LargeIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.CharType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.types.coercion.FractionalType;
import org.apache.doris.nereids.types.coercion.IntegralType;
import org.apache.doris.nereids.types.coercion.NumericType;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.stream.Collectors;
/**
* function binder
*/
class FunctionBinder extends DefaultExpressionRewriter<CascadesContext> {
public static final FunctionBinder INSTANCE = new FunctionBinder();
public <E extends Expression> E bind(E expression, CascadesContext context) {
return (E) expression.accept(this, context);
}
@Override
public Expression visit(Expression expr, CascadesContext context) {
expr = super.visit(expr, context);
expr.checkLegalityBeforeTypeCoercion();
if (expr instanceof ImplicitCastInputTypes) {
List<AbstractDataType> expectedInputTypes = ((ImplicitCastInputTypes) expr).expectedInputTypes();
if (!expectedInputTypes.isEmpty()) {
return visitImplicitCastInputTypes(expr, expectedInputTypes);
}
}
return expr;
}
/* ********************************************************************************************
* bind function
* ******************************************************************************************** */
@Override
public Expression visitUnboundFunction(UnboundFunction unboundFunction, CascadesContext context) {
unboundFunction = unboundFunction.withChildren(unboundFunction.children().stream()
.map(e -> e.accept(this, context)).collect(Collectors.toList()));
// bind function
FunctionRegistry functionRegistry = context.getConnectContext().getEnv().getFunctionRegistry();
String functionName = unboundFunction.getName();
List<Object> arguments = unboundFunction.isDistinct()
? ImmutableList.builder()
.add(unboundFunction.isDistinct())
.addAll(unboundFunction.getArguments())
.build()
: (List) unboundFunction.getArguments();
FunctionBuilder builder = functionRegistry.findFunctionBuilder(functionName, arguments);
BoundFunction boundFunction = builder.build(functionName, arguments);
// check
boundFunction.checkLegalityBeforeTypeCoercion();
// type coercion
return visitImplicitCastInputTypes(boundFunction, boundFunction.expectedInputTypes());
}
/**
* gets the method for calculating the time.
* e.g. YEARS_ADD, YEARS_SUB, DAYS_ADD, DAYS_SUB
*/
@Override
public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, CascadesContext context) {
Expression left = arithmetic.left().accept(this, context);
Expression right = arithmetic.right().accept(this, context);
// bind function
arithmetic = (TimestampArithmetic) arithmetic.withChildren(left, right);
String funcOpName;
if (arithmetic.getFuncName() == null) {
// e.g. YEARS_ADD, MONTHS_SUB
funcOpName = String.format("%sS_%s", arithmetic.getTimeUnit(),
(arithmetic.getOp() == Operator.ADD) ? "ADD" : "SUB");
} else {
funcOpName = arithmetic.getFuncName();
}
arithmetic = (TimestampArithmetic) arithmetic.withFuncName(funcOpName.toLowerCase(Locale.ROOT));
// type coercion
return visitImplicitCastInputTypes(arithmetic, arithmetic.expectedInputTypes());
}
/* ********************************************************************************************
* type coercion
* ******************************************************************************************** */
@Override
public Expression visitIntegralDivide(IntegralDivide integralDivide, CascadesContext context) {
Expression left = integralDivide.left().accept(this, context);
Expression right = integralDivide.right().accept(this, context);
// check before bind
integralDivide.checkLegalityBeforeTypeCoercion();
// type coercion
Expression newLeft = TypeCoercionUtils.castIfNotSameType(left, BigIntType.INSTANCE);
Expression newRight = TypeCoercionUtils.castIfNotSameType(right, BigIntType.INSTANCE);
return integralDivide.withChildren(newLeft, newRight);
}
@Override
public Expression visitBinaryOperator(BinaryOperator binaryOperator, CascadesContext context) {
Expression left = binaryOperator.left().accept(this, context);
Expression right = binaryOperator.right().accept(this, context);
// check
binaryOperator.checkLegalityBeforeTypeCoercion();
// characterLiteralTypeCoercion
if (left instanceof Literal || right instanceof Literal) {
if (left instanceof Literal && ((Literal) left).isCharacterLiteral()) {
left = characterLiteralTypeCoercion(((Literal) left).getStringValue(), right.getDataType())
.orElse(left);
}
if (right instanceof Literal && ((Literal) right).isCharacterLiteral()) {
right = characterLiteralTypeCoercion(((Literal) right).getStringValue(), left.getDataType())
.orElse(right);
}
}
binaryOperator = (BinaryOperator) binaryOperator.withChildren(left, right);
// type coercion
if (binaryOperator instanceof ImplicitCastInputTypes) {
List<AbstractDataType> expectedInputTypes = ((ImplicitCastInputTypes) binaryOperator).expectedInputTypes();
if (!expectedInputTypes.isEmpty()) {
binaryOperator.children().stream().filter(e -> e instanceof StringLikeLiteral)
.forEach(expr -> {
try {
new BigDecimal(((StringLikeLiteral) expr).getStringValue());
} catch (NumberFormatException e) {
throw new IllegalStateException(String.format(
"string literal %s cannot be cast to double", expr.toSql()));
}
});
binaryOperator = (BinaryOperator) visitImplicitCastInputTypes(binaryOperator, expectedInputTypes);
}
}
BinaryOperator op = binaryOperator;
Expression opLeft = op.left();
Expression opRight = op.right();
return Optional.of(TypeCoercionUtils.canHandleTypeCoercion(left.getDataType(), right.getDataType()))
.filter(Boolean::booleanValue)
.map(b -> TypeCoercionUtils.findTightestCommonType(op, opLeft.getDataType(), opRight.getDataType()))
.filter(Optional::isPresent)
.map(Optional::get)
.filter(ct -> op.inputType().acceptsType(ct))
.filter(ct -> !opLeft.getDataType().equals(ct) || !opRight.getDataType().equals(ct))
.map(commonType -> {
Expression newLeft = TypeCoercionUtils.castIfNotSameType(opLeft, commonType);
Expression newRight = TypeCoercionUtils.castIfNotSameType(opRight, commonType);
return op.withChildren(newLeft, newRight);
})
.orElse(op.withChildren(opLeft, opRight));
}
@Override
public Expression visitDivide(Divide divide, CascadesContext context) {
Expression left = divide.left().accept(this, context);
Expression right = divide.right().accept(this, context);
divide = (Divide) divide.withChildren(left, right);
// check
divide.checkLegalityBeforeTypeCoercion();
// type coercion
DataType t1 = TypeCoercionUtils.getNumResultType(left.getDataType());
DataType t2 = TypeCoercionUtils.getNumResultType(right.getDataType());
DataType commonType = TypeCoercionUtils.findCommonNumericsType(t1, t2);
if (divide.getLegacyOperator() == Operator.DIVIDE
&& (commonType.isBigIntType() || commonType.isLargeIntType())) {
commonType = DoubleType.INSTANCE;
}
Expression newLeft = TypeCoercionUtils.castIfNotSameType(left, commonType);
Expression newRight = TypeCoercionUtils.castIfNotSameType(right, commonType);
return divide.withChildren(newLeft, newRight);
}
@Override
public Expression visitCaseWhen(CaseWhen caseWhen, CascadesContext context) {
List<Expression> rewrittenChildren = caseWhen.children().stream()
.map(e -> e.accept(this, context)).collect(Collectors.toList());
CaseWhen newCaseWhen = caseWhen.withChildren(rewrittenChildren);
// check
newCaseWhen.checkLegalityBeforeTypeCoercion();
// type coercion
List<DataType> dataTypesForCoercion = newCaseWhen.dataTypesForCoercion();
if (dataTypesForCoercion.size() <= 1) {
return newCaseWhen;
}
DataType first = dataTypesForCoercion.get(0);
if (dataTypesForCoercion.stream().allMatch(dataType -> dataType.equals(first))) {
return newCaseWhen;
}
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(dataTypesForCoercion);
return optionalCommonType
.map(commonType -> {
List<Expression> newChildren
= newCaseWhen.getWhenClauses().stream()
.map(wc -> wc.withChildren(wc.getOperand(),
TypeCoercionUtils.castIfNotSameType(wc.getResult(), commonType)))
.collect(Collectors.toList());
newCaseWhen.getDefaultValue()
.map(dv -> TypeCoercionUtils.castIfNotSameType(dv, commonType))
.ifPresent(newChildren::add);
return newCaseWhen.withChildren(newChildren);
})
.orElse(newCaseWhen);
}
@Override
public Expression visitInPredicate(InPredicate inPredicate, CascadesContext context) {
List<Expression> rewrittenChildren = inPredicate.children().stream()
.map(e -> e.accept(this, context)).collect(Collectors.toList());
InPredicate newInPredicate = inPredicate.withChildren(rewrittenChildren);
// check
newInPredicate.checkLegalityBeforeTypeCoercion();
// type coercion
if (newInPredicate.getOptions().stream().map(Expression::getDataType)
.allMatch(dt -> dt.equals(newInPredicate.getCompareExpr().getDataType()))) {
return newInPredicate;
}
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(newInPredicate.children()
.stream().map(Expression::getDataType).collect(Collectors.toList()));
return optionalCommonType
.map(commonType -> {
List<Expression> newChildren = newInPredicate.children().stream()
.map(e -> TypeCoercionUtils.castIfNotSameType(e, commonType))
.collect(Collectors.toList());
return newInPredicate.withChildren(newChildren);
})
.orElse(newInPredicate);
}
@Override
public Expression visitBitNot(BitNot bitNot, CascadesContext context) {
Expression child = bitNot.child().accept(this, context);
// check
bitNot.checkLegalityBeforeTypeCoercion();
// type coercion
if (child.getDataType().toCatalogDataType().getPrimitiveType().ordinal() > PrimitiveType.LARGEINT.ordinal()) {
child = new Cast(child, BigIntType.INSTANCE);
}
return bitNot.withChildren(child);
}
private Optional<Expression> characterLiteralTypeCoercion(String value, DataType dataType) {
Expression ret = null;
try {
if (dataType instanceof BooleanType) {
if ("true".equalsIgnoreCase(value)) {
ret = BooleanLiteral.TRUE;
}
if ("false".equalsIgnoreCase(value)) {
ret = BooleanLiteral.FALSE;
}
} else if (dataType instanceof IntegralType) {
BigInteger bigInt = new BigInteger(value);
if (BigInteger.valueOf(bigInt.byteValue()).equals(bigInt)) {
ret = new TinyIntLiteral(bigInt.byteValue());
} else if (BigInteger.valueOf(bigInt.shortValue()).equals(bigInt)) {
ret = new SmallIntLiteral(bigInt.shortValue());
} else if (BigInteger.valueOf(bigInt.intValue()).equals(bigInt)) {
ret = new IntegerLiteral(bigInt.intValue());
} else if (BigInteger.valueOf(bigInt.longValue()).equals(bigInt)) {
ret = new BigIntLiteral(bigInt.longValueExact());
} else {
ret = new LargeIntLiteral(bigInt);
}
} else if (dataType instanceof FloatType) {
ret = new FloatLiteral(Float.parseFloat(value));
} else if (dataType instanceof DoubleType) {
ret = new DoubleLiteral(Double.parseDouble(value));
} else if (dataType instanceof DecimalV2Type) {
ret = new DecimalLiteral(new BigDecimal(value));
} else if (dataType instanceof CharType) {
ret = new CharLiteral(value, value.length());
} else if (dataType instanceof VarcharType) {
ret = new VarcharLiteral(value, value.length());
} else if (dataType instanceof StringType) {
ret = new StringLiteral(value);
} else if (dataType instanceof DateType) {
ret = new DateLiteral(value);
} else if (dataType instanceof DateTimeType) {
ret = new DateTimeLiteral(value);
}
} catch (Exception e) {
// ignore
}
return Optional.ofNullable(ret);
}
private Expression visitImplicitCastInputTypes(Expression expr, List<AbstractDataType> expectedInputTypes) {
List<Optional<DataType>> inputImplicitCastTypes
= getInputImplicitCastTypes(expr.children(), expectedInputTypes);
return castInputs(expr, inputImplicitCastTypes);
}
private List<Optional<DataType>> getInputImplicitCastTypes(
List<Expression> inputs, List<AbstractDataType> expectedTypes) {
Builder<Optional<DataType>> implicitCastTypes = ImmutableList.builder();
for (int i = 0; i < inputs.size(); i++) {
DataType argType = inputs.get(i).getDataType();
AbstractDataType expectedType = expectedTypes.get(i);
Optional<DataType> castType = TypeCoercionUtils.implicitCast(argType, expectedType);
// TODO: complete the cast logic like FunctionCallExpr.analyzeImpl
boolean legacyCastCompatible = expectedType instanceof DataType
&& !(expectedType.getClass().equals(NumericType.class))
&& !(expectedType.getClass().equals(IntegralType.class))
&& !(expectedType.getClass().equals(FractionalType.class))
&& !(expectedType.getClass().equals(CharacterType.class))
&& !argType.toCatalogDataType().matchesType(expectedType.toCatalogDataType());
if (!castType.isPresent() && legacyCastCompatible) {
castType = Optional.of((DataType) expectedType);
}
implicitCastTypes.add(castType);
}
return implicitCastTypes.build();
}
private Expression castInputs(Expression expr, List<Optional<DataType>> castTypes) {
return expr.withChildren((child, childIndex) -> {
DataType argType = child.getDataType();
Optional<DataType> castType = castTypes.get(childIndex);
if (castType.isPresent() && !castType.get().equals(argType)) {
return TypeCoercionUtils.castIfNotSameType(child, castType.get());
} else {
return child;
}
});
}
}

View File

@ -29,7 +29,6 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.planner.PlannerContext;
import org.apache.commons.lang.StringUtils;
@ -37,9 +36,9 @@ import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
class Binder extends SubExprAnalyzer {
class SlotBinder extends SubExprAnalyzer {
public Binder(Scope scope, CascadesContext cascadesContext) {
public SlotBinder(Scope scope, CascadesContext cascadesContext) {
super(scope, cascadesContext);
}
@ -48,7 +47,7 @@ class Binder extends SubExprAnalyzer {
}
@Override
public Expression visitUnboundAlias(UnboundAlias unboundAlias, PlannerContext context) {
public Expression visitUnboundAlias(UnboundAlias unboundAlias, CascadesContext context) {
Expression child = unboundAlias.child().accept(this, context);
if (unboundAlias.getAlias().isPresent()) {
return new Alias(child, unboundAlias.getAlias().get());
@ -62,7 +61,7 @@ class Binder extends SubExprAnalyzer {
}
@Override
public Slot visitUnboundSlot(UnboundSlot unboundSlot, PlannerContext context) {
public Slot visitUnboundSlot(UnboundSlot unboundSlot, CascadesContext context) {
Optional<List<Slot>> boundedOpt = Optional.of(bindSlot(unboundSlot, getScope().getSlots()));
boolean foundInThisScope = !boundedOpt.get().isEmpty();
// Currently only looking for symbols on the previous level.
@ -110,7 +109,7 @@ class Binder extends SubExprAnalyzer {
}
@Override
public Expression visitUnboundStar(UnboundStar unboundStar, PlannerContext context) {
public Expression visitUnboundStar(UnboundStar unboundStar, CascadesContext context) {
List<String> qualifier = unboundStar.getQualifier();
List<Slot> slots = getScope().getSlots()
.stream()

View File

@ -33,7 +33,6 @@ import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewri
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.planner.PlannerContext;
import com.google.common.collect.ImmutableList;
@ -45,7 +44,7 @@ import java.util.Optional;
/**
* Use the visitor to iterate sub expression.
*/
public class SubExprAnalyzer extends DefaultExpressionRewriter<PlannerContext> {
class SubExprAnalyzer extends DefaultExpressionRewriter<CascadesContext> {
private final Scope scope;
private final CascadesContext cascadesContext;
@ -56,7 +55,7 @@ public class SubExprAnalyzer extends DefaultExpressionRewriter<PlannerContext> {
}
@Override
public Expression visitNot(Not not, PlannerContext context) {
public Expression visitNot(Not not, CascadesContext context) {
Expression child = not.child();
if (child instanceof Exists) {
return visitExistsSubquery(
@ -69,7 +68,7 @@ public class SubExprAnalyzer extends DefaultExpressionRewriter<PlannerContext> {
}
@Override
public Expression visitExistsSubquery(Exists exists, PlannerContext context) {
public Expression visitExistsSubquery(Exists exists, CascadesContext context) {
AnalyzedResult analyzedResult = analyzeSubquery(exists);
return new Exists(analyzedResult.getLogicalPlan(),
@ -77,7 +76,7 @@ public class SubExprAnalyzer extends DefaultExpressionRewriter<PlannerContext> {
}
@Override
public Expression visitInSubquery(InSubquery expr, PlannerContext context) {
public Expression visitInSubquery(InSubquery expr, CascadesContext context) {
AnalyzedResult analyzedResult = analyzeSubquery(expr);
checkOutputColumn(analyzedResult.getLogicalPlan());
@ -90,7 +89,7 @@ public class SubExprAnalyzer extends DefaultExpressionRewriter<PlannerContext> {
}
@Override
public Expression visitScalarSubquery(ScalarSubquery scalar, PlannerContext context) {
public Expression visitScalarSubquery(ScalarSubquery scalar, CascadesContext context) {
AnalyzedResult analyzedResult = analyzeSubquery(scalar);
checkOutputColumn(analyzedResult.getLogicalPlan());

View File

@ -45,6 +45,7 @@ public class ExpressionNormalization extends ExpressionRewrite {
BetweenToCompoundRule.INSTANCE,
InPredicateToEqualToRule.INSTANCE,
SimplifyNotExprRule.INSTANCE,
// TODO(morrySnow): remove type coercion from here after we can process subquery type coercion during binding
CharacterLiteralTypeCoercion.INSTANCE,
SimplifyArithmeticRule.INSTANCE,
TypeCoercion.INSTANCE,

View File

@ -57,6 +57,7 @@ import java.util.Optional;
/**
* coercion character literal to another side
*/
@Deprecated
@DependsRules(CheckLegalityBeforeTypeCoercion.class)
public class CharacterLiteralTypeCoercion extends AbstractExpressionRewriteRule {

View File

@ -20,7 +20,6 @@ package org.apache.doris.nereids.rules.expression.rewrite.rules;
import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.nereids.annotation.DependsRules;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.jobs.batch.CheckLegalityBeforeTypeCoercion;
import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
@ -56,7 +55,7 @@ import java.util.stream.Collectors;
* a rule to add implicit cast for expressions.
* This class is inspired by spark's TypeCoercion.
*/
@Developing
@Deprecated
@DependsRules(CheckLegalityBeforeTypeCoercion.class)
public class TypeCoercion extends AbstractExpressionRewriteRule {

View File

@ -21,8 +21,6 @@ import org.apache.doris.catalog.SchemaTable;
import org.apache.doris.catalog.TableIf;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.RelationId;
@ -30,8 +28,6 @@ import org.apache.doris.nereids.trees.plans.algebra.Scan;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Optional;
@ -59,14 +55,6 @@ public class LogicalSchemaScan extends LogicalRelation implements Scan {
return visitor.visitLogicalSchemaScan(this, context);
}
@Override
public List<Slot> computeNonUserVisibleOutput() {
SchemaTable schemaTable = getTable();
return schemaTable.getBaseSchema().stream()
.map(col -> SlotReference.fromColumn(col, qualified()))
.collect(ImmutableList.toImmutableList());
}
@Override
public Plan withGroupExpression(Optional<GroupExpression> groupExpression) {
return new LogicalSchemaScan(id, table, qualifier, groupExpression, Optional.of(getLogicalProperties()));

View File

@ -314,7 +314,7 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Patt
ExceptionChecker.expectThrowsWithMsg(
AnalysisException.class,
"Aggregate functions in having clause can't be nested:"
+ " sum(cast((cast(a1 as DOUBLE) + avg(a2)) as SMALLINT)).",
+ " sum((cast(a1 as DOUBLE) + avg(a2))).",
() -> PlanChecker.from(connectContext).analyze(
"SELECT a1 FROM t1 GROUP BY a1 HAVING SUM(a1 + AVG(a2)) > 0"
));
@ -322,7 +322,7 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Patt
ExceptionChecker.expectThrowsWithMsg(
AnalysisException.class,
"Aggregate functions in having clause can't be nested:"
+ " sum(cast((cast((a1 + a2) as DOUBLE) + avg(a2)) as INT)).",
+ " sum((cast((a1 + a2) as DOUBLE) + avg(a2))).",
() -> PlanChecker.from(connectContext).analyze(
"SELECT a1 FROM t1 GROUP BY a1 HAVING SUM(a1 + a2 + AVG(a2)) > 0"
));