[enhancement](Nereids) refine and speedup analyzer (#31792) (#32111)

## Proposed changes
1. check data type whether can applied should not throw exception when real data type is subclass of signature data type
2. merge `SlotBinder` and `FunctionBinder` to `ExpressionAnalyzer` to skip rewrite the whole expression tree multiple times.
3. `ExpressionAnalyzer.buildCustomSlotBinderAnalyzer()` provide more refined code to bind slot by different parts and different priority
4. the origin slot binder has O(n^2) complexity, this pr use `Scope.nameToSlot` to support O(n) bind
5. modify some `Collection.stream()` to `ImmutableXxx.builder()` to remove some method call which are difficult to inline by jvm in the hot path, e.g. `Expression.<init>` and `AbstractTreeNode.<init>`
6. modify some `ImmutableXxx.copyOf(xxx)` to `Utils.fastToImmutableList(xxx)` to skip addition copy of the array
7. set init size to `Immutable.builder()` to skip some useless resize
8. lazy compute and cache some heavy operations, like `Scope.nameToSlot` and `CaseWhen.computeDataTypesForCoercion()`

(cherry picked from commit 83c2f5a95827136aac4f0a78c5e841e9a099858c)
This commit is contained in:
924060929
2024-03-12 17:09:38 +08:00
committed by GitHub
parent 5f125bbaaa
commit cf04c9c300
47 changed files with 2097 additions and 964 deletions

View File

@ -98,10 +98,15 @@ public class FunctionRegistry {
String combinatorSuffix = AggCombinerFunctionBuilder.getCombinatorSuffix(name);
functionBuilders = name2InternalBuiltinBuilders.get(nestedName.toLowerCase());
if (functionBuilders != null) {
functionBuilders = functionBuilders.stream()
.map(builder -> new AggCombinerFunctionBuilder(combinatorSuffix, builder))
.filter(functionBuilder -> functionBuilder.canApply(arguments))
.collect(Collectors.toList());
List<FunctionBuilder> candidateBuilders = Lists.newArrayListWithCapacity(functionBuilders.size());
for (FunctionBuilder functionBuilder : functionBuilders) {
AggCombinerFunctionBuilder combinerBuilder
= new AggCombinerFunctionBuilder(combinatorSuffix, functionBuilder);
if (combinerBuilder.canApply(arguments)) {
candidateBuilders.add(combinerBuilder);
}
}
functionBuilders = candidateBuilders;
}
}
}
@ -115,9 +120,12 @@ public class FunctionRegistry {
}
// check the arity and type
List<FunctionBuilder> candidateBuilders = functionBuilders.stream()
.filter(functionBuilder -> functionBuilder.canApply(arguments))
.collect(Collectors.toList());
List<FunctionBuilder> candidateBuilders = Lists.newArrayListWithCapacity(arguments.size());
for (FunctionBuilder functionBuilder : functionBuilders) {
if (functionBuilder.canApply(arguments)) {
candidateBuilders.add(functionBuilder);
}
}
if (candidateBuilders.isEmpty()) {
String candidateHints = getCandidateHint(name, functionBuilders);
throw new AnalysisException("Can not found function '" + qualifiedName

View File

@ -20,6 +20,7 @@ package org.apache.doris.catalog;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.FollowToArgumentType;
import org.apache.doris.nereids.util.Utils;
import com.google.common.base.MoreObjects;
import com.google.common.collect.ImmutableList;
@ -40,7 +41,7 @@ public class FunctionSignature {
private FunctionSignature(DataType returnType, boolean hasVarArgs,
List<? extends DataType> argumentsTypes) {
this.returnType = Objects.requireNonNull(returnType, "returnType is not null");
this.argumentsTypes = ImmutableList.copyOf(
this.argumentsTypes = Utils.fastToImmutableList(
Objects.requireNonNull(argumentsTypes, "argumentsTypes is not null"));
this.hasVarArgs = hasVarArgs;
this.arity = argumentsTypes.size();

View File

@ -72,6 +72,8 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.HashMap;
@ -92,6 +94,7 @@ import javax.annotation.Nullable;
* Context used in memo.
*/
public class CascadesContext implements ScheduleContext {
private static final Logger LOG = LogManager.getLogger(CascadesContext.class);
// in analyze/rewrite stage, the plan will storage in this field
private Plan plan;
@ -713,4 +716,10 @@ public class CascadesContext implements ScheduleContext {
task.run();
}
}
public void printPlanProcess() {
for (PlanProcess row : planProcesses) {
LOG.info("RULE: " + row.ruleName + "\nBEFORE:\n" + row.beforeShape + "\nafter:\n" + row.afterShape);
}
}
}

View File

@ -0,0 +1,22 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.analyzer;
/** ComplexDataType */
public interface ComplexDataType {
}

View File

@ -0,0 +1,115 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.analyzer;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import java.util.List;
/**
* MappingSlot.
* mapping slot to an expression, use to replace slot to this expression.
* this class only use in Scope, and **NEVER** appear in the expression tree
*/
public class MappingSlot extends Slot {
private final Slot slot;
private final Expression mappingExpression;
public MappingSlot(Slot slot, Expression mappingExpression) {
this.slot = slot;
this.mappingExpression = mappingExpression;
}
public Slot getRealSlot() {
return slot;
}
@Override
public List<String> getQualifier() {
return slot.getQualifier();
}
public Expression getMappingExpression() {
return mappingExpression;
}
@Override
public ExprId getExprId() {
return slot.getExprId();
}
@Override
public String getName() throws UnboundException {
return slot.getName();
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitSlot(this, context);
}
@Override
public boolean nullable() {
return slot.nullable();
}
@Override
public String toSql() {
return slot.toSql();
}
@Override
public String toString() {
return slot.toString();
}
@Override
public DataType getDataType() throws UnboundException {
return slot.getDataType();
}
@Override
public String getInternalName() {
return slot.getInternalName();
}
@Override
public Slot withName(String name) {
return this;
}
@Override
public Slot withNullable(boolean newNullable) {
return this;
}
@Override
public Slot withExprId(ExprId exprId) {
return this;
}
@Override
public Slot withQualifier(List<String> qualifier) {
return this;
}
}

View File

@ -19,14 +19,19 @@ package org.apache.doris.nereids.analyzer;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableList;
import com.google.common.base.Suppliers;
import com.google.common.collect.LinkedListMultimap;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Sets;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
/**
* The slot range required for expression analyze.
@ -58,16 +63,19 @@ public class Scope {
private final List<Slot> slots;
private final Optional<SubqueryExpr> ownerSubquery;
private final Set<Slot> correlatedSlots;
private final Supplier<ListMultimap<String, Slot>> nameToSlot;
public Scope(Optional<Scope> outerScope, List<Slot> slots, Optional<SubqueryExpr> subqueryExpr) {
this.outerScope = outerScope;
this.slots = ImmutableList.copyOf(Objects.requireNonNull(slots, "slots can not be null"));
this.ownerSubquery = subqueryExpr;
this.correlatedSlots = Sets.newLinkedHashSet();
public Scope(List<? extends Slot> slots) {
this(Optional.empty(), slots, Optional.empty());
}
public Scope(List<Slot> slots) {
this(Optional.empty(), slots, Optional.empty());
/** Scope */
public Scope(Optional<Scope> outerScope, List<? extends Slot> slots, Optional<SubqueryExpr> subqueryExpr) {
this.outerScope = Objects.requireNonNull(outerScope, "outerScope can not be null");
this.slots = Utils.fastToImmutableList(Objects.requireNonNull(slots, "slots can not be null"));
this.ownerSubquery = Objects.requireNonNull(subqueryExpr, "subqueryExpr can not be null");
this.correlatedSlots = Sets.newLinkedHashSet();
this.nameToSlot = Suppliers.memoize(this::buildNameToSlot);
}
public List<Slot> getSlots() {
@ -85,4 +93,17 @@ public class Scope {
public Set<Slot> getCorrelatedSlots() {
return correlatedSlots;
}
/** findSlotIgnoreCase */
public List<Slot> findSlotIgnoreCase(String slotName) {
return nameToSlot.get().get(slotName.toUpperCase(Locale.ROOT));
}
private ListMultimap<String, Slot> buildNameToSlot() {
ListMultimap<String, Slot> map = LinkedListMultimap.create(slots.size());
for (Slot slot : slots) {
map.put(slot.getName().toUpperCase(Locale.ROOT), slot);
}
return map;
}
}

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.jobs.rewrite;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.PlanProcess;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.JobType;
@ -27,8 +28,6 @@ import org.apache.doris.nereids.pattern.Pattern;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.plans.Plan;
import com.google.common.base.Preconditions;
import java.util.List;
/** PlanTreeRewriteJob */
@ -38,11 +37,9 @@ public abstract class PlanTreeRewriteJob extends Job {
super(type, context);
}
protected RewriteResult rewrite(Plan plan, List<Rule> rules, RewriteJobContext rewriteJobContext) {
// boolean traceEnable = isTraceEnable(context);
boolean isRewriteRoot = rewriteJobContext.isRewriteRoot();
protected final RewriteResult rewrite(Plan plan, List<Rule> rules, RewriteJobContext rewriteJobContext) {
CascadesContext cascadesContext = context.getCascadesContext();
cascadesContext.setIsRewriteRoot(isRewriteRoot);
cascadesContext.setIsRewriteRoot(rewriteJobContext.isRewriteRoot());
boolean showPlanProcess = cascadesContext.showPlanProcess();
for (Rule rule : rules) {
@ -52,8 +49,9 @@ public abstract class PlanTreeRewriteJob extends Job {
Pattern<Plan> pattern = (Pattern<Plan>) rule.getPattern();
if (pattern.matchPlanTree(plan)) {
List<Plan> newPlans = rule.transform(plan, cascadesContext);
Preconditions.checkState(newPlans.size() == 1,
"Rewrite rule should generate one plan: " + rule.getRuleType());
if (newPlans.size() != 1) {
throw new AnalysisException("Rewrite rule should generate one plan: " + rule.getRuleType());
}
Plan newPlan = newPlans.get(0);
if (!newPlan.deepEquals(plan)) {
// don't remove this comment, it can help us to trace some bug when developing.
@ -78,13 +76,13 @@ public abstract class PlanTreeRewriteJob extends Job {
return new RewriteResult(false, plan);
}
protected Plan linkChildrenAndParent(Plan plan, RewriteJobContext rewriteJobContext) {
protected final Plan linkChildrenAndParent(Plan plan, RewriteJobContext rewriteJobContext) {
Plan newPlan = linkChildren(plan, rewriteJobContext.childrenContext);
rewriteJobContext.setResult(newPlan);
return newPlan;
}
protected Plan linkChildren(Plan plan, RewriteJobContext[] childrenContext) {
protected final Plan linkChildren(Plan plan, RewriteJobContext[] childrenContext) {
boolean changed = false;
Plan[] newChildren = new Plan[childrenContext.length];
for (int i = 0; i < childrenContext.length; ++i) {

View File

@ -59,7 +59,7 @@ public class AvgDistinctToSumDivCount extends OneRewriteRuleFactory {
}));
if (!avgToSumDivCount.isEmpty()) {
List<NamedExpression> newOutput = agg.getOutputExpressions().stream()
.map(expr -> (NamedExpression) ExpressionUtils.replace(expr, avgToSumDivCount))
.map(expr -> ExpressionUtils.replaceNameExpression(expr, avgToSumDivCount))
.collect(ImmutableList.toImmutableList());
return new LogicalAggregate<>(agg.getGroupByExpressions(), newOutput, agg.child());
} else {

View File

@ -79,7 +79,7 @@ public class BindSlotWithPaths implements AnalysisRuleFactory {
return ctx.root;
}
newProjectsExpr.addAll(newExprs);
return new LogicalProject(newProjectsExpr, logicalOlapScan.withProjectPulledUp());
return new LogicalProject<>(newProjectsExpr, logicalOlapScan.withProjectPulledUp());
}))
);
}

View File

@ -128,13 +128,11 @@ public class CheckAnalysis implements AnalysisRuleFactory {
}
private void checkExpressionInputTypes(Plan plan) {
final Optional<TypeCheckResult> firstFailed = plan.getExpressions().stream()
.map(Expression::checkInputDataTypes)
.filter(TypeCheckResult::failed)
.findFirst();
if (firstFailed.isPresent()) {
throw new AnalysisException(firstFailed.get().getMessage());
for (Expression expression : plan.getExpressions()) {
TypeCheckResult firstFailed = expression.checkInputDataTypes();
if (firstFailed.failed()) {
throw new AnalysisException(firstFailed.getMessage());
}
}
}

View File

@ -91,6 +91,7 @@ public class EliminateLogicalSelectHint extends OneRewriteRuleFactory {
if (value.isPresent()) {
try {
VariableMgr.setVar(sessionVariable, new SetVar(key, new StringLiteral(value.get())));
context.invalidCache(key);
} catch (Throwable t) {
throw new AnalysisException("Can not set session variable '"
+ key + "' = '" + value.get() + "'", t);
@ -108,7 +109,6 @@ public class EliminateLogicalSelectHint extends OneRewriteRuleFactory {
}
throw new AnalysisException("The nereids is disabled in this sql, fallback to original planner");
}
context.invalidCache(selectHint.getHintName());
}
private void extractLeading(SelectHintLeading selectHint, CascadesContext context,

View File

@ -0,0 +1,807 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.analysis.SetType;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.common.DdlException;
import org.apache.doris.common.util.Util;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.analyzer.UnboundAlias;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.analyzer.UnboundStar;
import org.apache.doris.nereids.analyzer.UnboundVariable;
import org.apache.doris.nereids.analyzer.UnboundVariable.VariableType;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ArrayItemReference;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
import org.apache.doris.nereids.trees.expressions.BitNot;
import org.apache.doris.nereids.trees.expressions.BoundStar;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.InSubquery;
import org.apache.doris.nereids.trees.expressions.IntegralDivide;
import org.apache.doris.nereids.trees.expressions.ListQuery;
import org.apache.doris.nereids.trees.expressions.Match;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.Variable;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl;
import org.apache.doris.nereids.trees.expressions.functions.scalar.PushDownToProjectionFunction;
import org.apache.doris.nereids.trees.expressions.functions.udf.AliasUdfBuilder;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.GlobalVariable;
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.qe.VariableMgr;
import org.apache.doris.qe.VariableVarConverters;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.stream.Collectors;
/** ExpressionAnalyzer */
public class ExpressionAnalyzer extends SubExprAnalyzer<ExpressionRewriteContext> {
private final Plan currentPlan;
/*
bounded={table.a, a}
unbound=a
if enableExactMatch, 'a' is bound to bounded 'a',
if not enableExactMatch, 'a' is ambiguous
in order to be compatible to original planner,
exact match mode is not enabled for having clause
but enabled for order by clause
TODO after remove original planner, always enable exact match mode.
*/
private final boolean enableExactMatch;
private final boolean bindSlotInOuterScope;
private boolean currentInLambda;
public ExpressionAnalyzer(Plan currentPlan, Scope scope, CascadesContext cascadesContext,
boolean enableExactMatch, boolean bindSlotInOuterScope) {
super(scope, cascadesContext);
this.currentPlan = currentPlan;
this.enableExactMatch = enableExactMatch;
this.bindSlotInOuterScope = bindSlotInOuterScope;
}
public Expression analyze(Expression expression, ExpressionRewriteContext context) {
return expression.accept(this, context);
}
@Override
public Expression visit(Expression expr, ExpressionRewriteContext context) {
expr = super.visit(expr, context);
expr.checkLegalityBeforeTypeCoercion();
// this cannot be removed, because some function already construct in parser.
if (expr instanceof ImplicitCastInputTypes) {
List<DataType> expectedInputTypes = ((ImplicitCastInputTypes) expr).expectedInputTypes();
if (!expectedInputTypes.isEmpty()) {
return TypeCoercionUtils.implicitCastInputTypes(expr, expectedInputTypes);
}
}
return expr;
}
@Override
public Expression visitLambda(Lambda lambda, ExpressionRewriteContext context) {
boolean originInLambda = currentInLambda;
try {
currentInLambda = true;
return super.visitLambda(lambda, context);
} finally {
currentInLambda = originInLambda;
}
}
/* ********************************************************************************************
* bind slot
* ******************************************************************************************** */
@Override
public Expression visitUnboundVariable(UnboundVariable unboundVariable, ExpressionRewriteContext context) {
String name = unboundVariable.getName();
SessionVariable sessionVariable = ConnectContext.get().getSessionVariable();
Literal literal = null;
if (unboundVariable.getType() == VariableType.DEFAULT) {
literal = VariableMgr.getLiteral(sessionVariable, name, SetType.DEFAULT);
} else if (unboundVariable.getType() == VariableType.SESSION) {
literal = VariableMgr.getLiteral(sessionVariable, name, SetType.SESSION);
} else if (unboundVariable.getType() == VariableType.GLOBAL) {
literal = VariableMgr.getLiteral(sessionVariable, name, SetType.GLOBAL);
} else if (unboundVariable.getType() == VariableType.USER) {
literal = ConnectContext.get().getLiteralForUserVar(name);
}
if (literal == null) {
throw new AnalysisException("Unsupported system variable: " + unboundVariable.getName());
}
if (!Strings.isNullOrEmpty(name) && VariableVarConverters.hasConverter(name)) {
try {
Preconditions.checkArgument(literal instanceof IntegerLikeLiteral);
IntegerLikeLiteral integerLikeLiteral = (IntegerLikeLiteral) literal;
literal = new StringLiteral(VariableVarConverters.decode(name, integerLikeLiteral.getLongValue()));
} catch (DdlException e) {
throw new AnalysisException(e.getMessage());
}
}
return new Variable(unboundVariable.getName(), unboundVariable.getType(), literal);
}
@Override
public Expression visitUnboundAlias(UnboundAlias unboundAlias, ExpressionRewriteContext context) {
Expression child = unboundAlias.child().accept(this, context);
if (unboundAlias.getAlias().isPresent()) {
return new Alias(child, unboundAlias.getAlias().get());
// TODO: the variant bind element_at(slot, 'name') will return a slot, and we should
// assign an Alias to this function, this is trick and should refactor it
} else if (!(unboundAlias.child() instanceof ElementAt) && child instanceof NamedExpression) {
return new Alias(child, ((NamedExpression) child).getName());
} else {
return new Alias(child);
}
}
@Override
public Expression visitUnboundSlot(UnboundSlot unboundSlot, ExpressionRewriteContext context) {
Optional<Scope> outerScope = getScope().getOuterScope();
Optional<List<? extends Expression>> boundedOpt = Optional.of(bindSlotByThisScope(unboundSlot));
boolean foundInThisScope = !boundedOpt.get().isEmpty();
// Currently only looking for symbols on the previous level.
if (bindSlotInOuterScope && !foundInThisScope && outerScope.isPresent()) {
boundedOpt = Optional.of(bindSlotByScope(unboundSlot, outerScope.get()));
}
List<? extends Expression> bounded = boundedOpt.get();
switch (bounded.size()) {
case 0:
if (!currentInLambda) {
String tableName = StringUtils.join(unboundSlot.getQualifier(), ".");
if (tableName.isEmpty()) {
tableName = "table list";
}
throw new AnalysisException("Unknown column '"
+ unboundSlot.getNameParts().get(unboundSlot.getNameParts().size() - 1)
+ "' in '" + tableName + "' in "
+ currentPlan.getType().toString().substring("LOGICAL_".length()) + " clause");
}
return unboundSlot;
case 1:
Expression firstBound = bounded.get(0);
if (!foundInThisScope && firstBound instanceof Slot
&& !outerScope.get().getCorrelatedSlots().contains(firstBound)) {
outerScope.get().getCorrelatedSlots().add((Slot) firstBound);
}
return firstBound;
default:
if (enableExactMatch) {
// select t1.k k, t2.k
// from t1 join t2 order by k
//
// 't1.k k' is denoted by alias_k, its full name is 'k'
// 'order by k' is denoted as order_k, it full name is 'k'
// 't2.k' in select list, its full name is 't2.k'
//
// order_k can be bound on alias_k and t2.k
// alias_k is exactly matched, since its full name is exactly match full name of order_k
// t2.k is not exactly matched, since t2.k's full name is larger than order_k
List<Slot> exactMatch = bounded.stream()
.filter(Slot.class::isInstance)
.map(Slot.class::cast)
.filter(bound -> unboundSlot.getNameParts().size() == bound.getQualifier().size() + 1)
.collect(Collectors.toList());
if (exactMatch.size() == 1) {
return exactMatch.get(0);
}
}
throw new AnalysisException(String.format("%s is ambiguous: %s.",
unboundSlot.toSql(),
bounded.stream()
.map(Expression::toString)
.collect(Collectors.joining(", "))));
}
}
@Override
public Expression visitUnboundStar(UnboundStar unboundStar, ExpressionRewriteContext context) {
List<String> qualifier = unboundStar.getQualifier();
boolean showHidden = Util.showHiddenColumns();
List<Slot> slots = getScope().getSlots()
.stream()
.filter(slot -> !(slot instanceof SlotReference)
|| (((SlotReference) slot).isVisible()) || showHidden)
.filter(slot -> !(((SlotReference) slot).hasSubColPath()))
.collect(Collectors.toList());
switch (qualifier.size()) {
case 0: // select *
return new BoundStar(slots);
case 1: // select table.*
case 2: // select db.table.*
case 3: // select catalog.db.table.*
return bindQualifiedStar(qualifier, slots);
default:
throw new AnalysisException("Not supported qualifier: "
+ StringUtils.join(qualifier, "."));
}
}
/* ********************************************************************************************
* bind function
* ******************************************************************************************** */
@Override
public Expression visitUnboundFunction(UnboundFunction unboundFunction, ExpressionRewriteContext context) {
if (unboundFunction.isHighOrder()) {
unboundFunction = bindHighOrderFunction(unboundFunction, context);
} else {
unboundFunction = (UnboundFunction) rewriteChildren(this, unboundFunction, context);
}
// bind function
FunctionRegistry functionRegistry = Env.getCurrentEnv().getFunctionRegistry();
List<Object> arguments = unboundFunction.isDistinct()
? ImmutableList.builderWithExpectedSize(unboundFunction.arity() + 1)
.add(unboundFunction.isDistinct())
.addAll(unboundFunction.getArguments())
.build()
: (List) unboundFunction.getArguments();
if (StringUtils.isEmpty(unboundFunction.getDbName())) {
// we will change arithmetic function like add(), subtract(), bitnot()
// to the corresponding objects rather than BoundFunction.
ArithmeticFunctionBinder functionBinder = new ArithmeticFunctionBinder();
if (functionBinder.isBinaryArithmetic(unboundFunction.getName())) {
return functionBinder.bindBinaryArithmetic(unboundFunction.getName(), unboundFunction.children())
.accept(this, context);
}
}
String functionName = unboundFunction.getName();
FunctionBuilder builder = functionRegistry.findFunctionBuilder(
unboundFunction.getDbName(), functionName, arguments);
if (builder instanceof AliasUdfBuilder) {
// we do type coercion in build function in alias function, so it's ok to return directly.
return builder.build(functionName, arguments);
} else {
Expression boundFunction = TypeCoercionUtils
.processBoundFunction((BoundFunction) builder.build(functionName, arguments));
if (boundFunction instanceof Count
&& context.cascadesContext.getOuterScope().isPresent()
&& !context.cascadesContext.getOuterScope().get().getCorrelatedSlots()
.isEmpty()) {
// consider sql: SELECT * FROM t1 WHERE t1.a <= (SELECT COUNT(t2.a) FROM t2 WHERE (t1.b = t2.b));
// when unnest correlated subquery, we create a left join node.
// outer query is left table and subquery is right one
// if there is no match, the row from right table is filled with nulls
// but COUNT function is always not nullable.
// so wrap COUNT with Nvl to ensure it's result is 0 instead of null to get the correct result
boundFunction = new Nvl(boundFunction, new BigIntLiteral(0));
}
return boundFunction;
}
}
@Override
public Expression visitBoundFunction(BoundFunction boundFunction, ExpressionRewriteContext context) {
boundFunction = (BoundFunction) super.visitBoundFunction(boundFunction, context);
return TypeCoercionUtils.processBoundFunction(boundFunction);
}
@Override
public Expression visitElementAt(ElementAt elementAt, ExpressionRewriteContext context) {
ElementAt boundFunction = (ElementAt) visitBoundFunction(elementAt, context);
if (PushDownToProjectionFunction.validToPushDown(boundFunction)) {
if (ConnectContext.get() != null
&& ConnectContext.get().getSessionVariable() != null
&& !ConnectContext.get().getSessionVariable().isEnableRewriteElementAtToSlot()) {
return boundFunction;
}
Slot slot = boundFunction.getInputSlots().stream().findFirst().get();
if (slot.hasUnbound()) {
slot = (Slot) slot.accept(this, context);
}
StatementContext statementContext = context.cascadesContext.getStatementContext();
Expression originBoundFunction = boundFunction.rewriteUp(expr -> {
if (expr instanceof SlotReference) {
Expression originalExpr = statementContext.getOriginalExpr((SlotReference) expr);
return originalExpr == null ? expr : originalExpr;
}
return expr;
});
// rewrite to slot and bound this slot
return PushDownToProjectionFunction.rewriteToSlot(
(PushDownToProjectionFunction) originBoundFunction, (SlotReference) slot);
}
return boundFunction;
}
/**
* gets the method for calculating the time.
* e.g. YEARS_ADD、YEARS_SUB、DAYS_ADD 、DAYS_SUB
*/
@Override
public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, ExpressionRewriteContext context) {
Expression left = arithmetic.left().accept(this, context);
Expression right = arithmetic.right().accept(this, context);
arithmetic = (TimestampArithmetic) arithmetic.withChildren(left, right);
// bind function
String funcOpName;
if (arithmetic.getFuncName() == null) {
// e.g. YEARS_ADD, MONTHS_SUB
funcOpName = String.format("%sS_%s", arithmetic.getTimeUnit(),
(arithmetic.getOp() == Operator.ADD) ? "ADD" : "SUB");
} else {
funcOpName = arithmetic.getFuncName();
}
arithmetic = (TimestampArithmetic) arithmetic.withFuncName(funcOpName.toLowerCase(Locale.ROOT));
// type coercion
return TypeCoercionUtils.processTimestampArithmetic(arithmetic);
}
/* ********************************************************************************************
* type coercion
* ******************************************************************************************** */
@Override
public Expression visitBitNot(BitNot bitNot, ExpressionRewriteContext context) {
Expression child = bitNot.child().accept(this, context);
// type coercion
if (!(child.getDataType().isIntegralType() || child.getDataType().isBooleanType())) {
child = new Cast(child, BigIntType.INSTANCE);
}
return bitNot.withChildren(child);
}
@Override
public Expression visitDivide(Divide divide, ExpressionRewriteContext context) {
Expression left = divide.left().accept(this, context);
Expression right = divide.right().accept(this, context);
divide = (Divide) divide.withChildren(left, right);
// type coercion
return TypeCoercionUtils.processDivide(divide);
}
@Override
public Expression visitIntegralDivide(IntegralDivide integralDivide, ExpressionRewriteContext context) {
Expression left = integralDivide.left().accept(this, context);
Expression right = integralDivide.right().accept(this, context);
integralDivide = (IntegralDivide) integralDivide.withChildren(left, right);
// type coercion
return TypeCoercionUtils.processIntegralDivide(integralDivide);
}
@Override
public Expression visitBinaryArithmetic(BinaryArithmetic binaryArithmetic, ExpressionRewriteContext context) {
Expression left = binaryArithmetic.left().accept(this, context);
Expression right = binaryArithmetic.right().accept(this, context);
binaryArithmetic = (BinaryArithmetic) binaryArithmetic.withChildren(left, right);
return TypeCoercionUtils.processBinaryArithmetic(binaryArithmetic);
}
@Override
public Expression visitCompoundPredicate(CompoundPredicate compoundPredicate, ExpressionRewriteContext context) {
Expression left = compoundPredicate.left().accept(this, context);
Expression right = compoundPredicate.right().accept(this, context);
CompoundPredicate ret = (CompoundPredicate) compoundPredicate.withChildren(left, right);
return TypeCoercionUtils.processCompoundPredicate(ret);
}
@Override
public Expression visitNot(Not not, ExpressionRewriteContext context) {
// maybe is `not subquery`, we should bind it first
Expression expr = super.visitNot(not, context);
// expression is not subquery
if (expr instanceof Not) {
Expression child = not.child().accept(this, context);
Expression newChild = TypeCoercionUtils.castIfNotSameType(child, BooleanType.INSTANCE);
if (child != newChild) {
return expr.withChildren(newChild);
}
}
return expr;
}
@Override
public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRewriteContext context) {
Expression left = cp.left().accept(this, context);
Expression right = cp.right().accept(this, context);
cp = (ComparisonPredicate) cp.withChildren(left, right);
return TypeCoercionUtils.processComparisonPredicate(cp);
}
@Override
public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) {
Builder<Expression> rewrittenChildren = ImmutableList.builderWithExpectedSize(caseWhen.arity());
for (Expression child : caseWhen.children()) {
rewrittenChildren.add(child.accept(this, context));
}
CaseWhen newCaseWhen = caseWhen.withChildren(rewrittenChildren.build());
newCaseWhen.checkLegalityBeforeTypeCoercion();
return TypeCoercionUtils.processCaseWhen(newCaseWhen);
}
@Override
public Expression visitWhenClause(WhenClause whenClause, ExpressionRewriteContext context) {
return whenClause.withChildren(TypeCoercionUtils.castIfNotSameType(
whenClause.getOperand().accept(this, context), BooleanType.INSTANCE),
whenClause.getResult().accept(this, context));
}
@Override
public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) {
List<Expression> rewrittenChildren = inPredicate.children().stream()
.map(e -> e.accept(this, context)).collect(Collectors.toList());
InPredicate newInPredicate = inPredicate.withChildren(rewrittenChildren);
return TypeCoercionUtils.processInPredicate(newInPredicate);
}
@Override
public Expression visitInSubquery(InSubquery inSubquery, ExpressionRewriteContext context) {
// analyze subquery
inSubquery = (InSubquery) super.visitInSubquery(inSubquery, context);
// compareExpr already analyze when invoke super.visitInSubquery
Expression newCompareExpr = inSubquery.getCompareExpr();
// but ListQuery does not analyze
Expression newListQuery = inSubquery.getListQuery().accept(this, context);
ComparisonPredicate afterTypeCoercion = (ComparisonPredicate) TypeCoercionUtils.processComparisonPredicate(
new EqualTo(newCompareExpr, newListQuery));
if (newListQuery.getDataType().isBitmapType()) {
if (!newCompareExpr.getDataType().isBigIntType()) {
newCompareExpr = new Cast(newCompareExpr, BigIntType.INSTANCE);
}
} else {
newCompareExpr = afterTypeCoercion.left();
}
return new InSubquery(newCompareExpr, (ListQuery) afterTypeCoercion.right(),
inSubquery.getCorrelateSlots(), ((ListQuery) afterTypeCoercion.right()).getTypeCoercionExpr(),
inSubquery.isNot());
}
@Override
public Expression visitMatch(Match match, ExpressionRewriteContext context) {
Expression left = match.left().accept(this, context);
Expression right = match.right().accept(this, context);
// check child type
if (!left.getDataType().isStringLikeType()
&& !(left.getDataType() instanceof ArrayType
&& ((ArrayType) left.getDataType()).getItemType().isStringLikeType())
&& !left.getDataType().isVariantType()) {
throw new AnalysisException(String.format(
"left operand '%s' part of predicate "
+ "'%s' should return type 'STRING', 'ARRAY<STRING> or VARIANT' but "
+ "returns type '%s'.",
left.toSql(), match.toSql(), left.getDataType()));
}
if (!right.getDataType().isStringLikeType() && !right.getDataType().isNullType()) {
throw new AnalysisException(String.format(
"right operand '%s' part of predicate " + "'%s' should return type 'STRING' but "
+ "returns type '%s'.",
right.toSql(), match.toSql(), right.getDataType()));
}
if (left.getDataType().isVariantType()) {
left = new Cast(left, right.getDataType());
}
return match.withChildren(left, right);
}
@Override
public Expression visitCast(Cast cast, ExpressionRewriteContext context) {
cast = (Cast) super.visitCast(cast, context);
// NOTICE: just for compatibility with legacy planner.
if (cast.child().getDataType().isComplexType() || cast.getDataType().isComplexType()) {
TypeCoercionUtils.checkCanCastTo(cast.child().getDataType(), cast.getDataType());
}
return cast;
}
private BoundStar bindQualifiedStar(List<String> qualifierStar, List<Slot> boundSlots) {
// FIXME: compatible with previous behavior:
// https://github.com/apache/doris/pull/10415/files/3fe9cb0c3f805ab3a9678033b281b16ad93ec60a#r910239452
List<Slot> slots = boundSlots.stream().filter(boundSlot -> {
switch (qualifierStar.size()) {
// table.*
case 1:
List<String> boundSlotQualifier = boundSlot.getQualifier();
switch (boundSlotQualifier.size()) {
// bound slot is `column` and no qualified
case 0:
return false;
case 1: // bound slot is `table`.`column`
return qualifierStar.get(0).equalsIgnoreCase(boundSlotQualifier.get(0));
case 2:// bound slot is `db`.`table`.`column`
return qualifierStar.get(0).equalsIgnoreCase(boundSlotQualifier.get(1));
case 3:// bound slot is `catalog`.`db`.`table`.`column`
return qualifierStar.get(0).equalsIgnoreCase(boundSlotQualifier.get(2));
default:
throw new AnalysisException("Not supported qualifier: "
+ StringUtils.join(qualifierStar, "."));
}
case 2: // db.table.*
boundSlotQualifier = boundSlot.getQualifier();
switch (boundSlotQualifier.size()) {
// bound slot is `column` and no qualified
case 0:
case 1: // bound slot is `table`.`column`
return false;
case 2:// bound slot is `db`.`table`.`column`
return compareDbName(qualifierStar.get(0), boundSlotQualifier.get(0))
&& qualifierStar.get(1).equalsIgnoreCase(boundSlotQualifier.get(1));
case 3:// bound slot is `catalog`.`db`.`table`.`column`
return compareDbName(qualifierStar.get(0), boundSlotQualifier.get(1))
&& qualifierStar.get(1).equalsIgnoreCase(boundSlotQualifier.get(2));
default:
throw new AnalysisException("Not supported qualifier: "
+ StringUtils.join(qualifierStar, ".") + ".*");
}
case 3: // catalog.db.table.*
boundSlotQualifier = boundSlot.getQualifier();
switch (boundSlotQualifier.size()) {
// bound slot is `column` and no qualified
case 0:
case 1: // bound slot is `table`.`column`
case 2: // bound slot is `db`.`table`.`column`
return false;
case 3:// bound slot is `catalog`.`db`.`table`.`column`
return qualifierStar.get(0).equalsIgnoreCase(boundSlotQualifier.get(0))
&& compareDbName(qualifierStar.get(1), boundSlotQualifier.get(1))
&& qualifierStar.get(2).equalsIgnoreCase(boundSlotQualifier.get(2));
default:
throw new AnalysisException("Not supported qualifier: "
+ StringUtils.join(qualifierStar, ".") + ".*");
}
default:
throw new AnalysisException("Not supported name: "
+ StringUtils.join(qualifierStar, ".") + ".*");
}
}).collect(Collectors.toList());
if (slots.isEmpty()) {
throw new AnalysisException("unknown qualifier: " + StringUtils.join(qualifierStar, ".") + ".*");
}
return new BoundStar(slots);
}
protected List<? extends Expression> bindSlotByThisScope(UnboundSlot unboundSlot) {
return bindSlotByScope(unboundSlot, getScope());
}
protected List<Slot> bindExactSlotsByThisScope(UnboundSlot unboundSlot, Scope scope) {
List<Slot> candidates = bindSlotByScope(unboundSlot, scope);
if (candidates.size() == 1) {
return candidates;
}
List<Slot> extractSlots = Utils.filterImmutableList(candidates, bound ->
unboundSlot.getNameParts().size() == bound.getQualifier().size() + 1
);
// we should return origin candidates slots if extract slots is empty,
// and then throw an ambiguous exception
return !extractSlots.isEmpty() ? extractSlots : candidates;
}
/** bindSlotByScope */
public List<Slot> bindSlotByScope(UnboundSlot unboundSlot, Scope scope) {
List<String> nameParts = unboundSlot.getNameParts();
int namePartSize = nameParts.size();
switch (namePartSize) {
// column
case 1: {
return bindSingleSlotByName(nameParts.get(0), scope);
}
// table.column
case 2: {
return bindSingleSlotByTable(nameParts.get(0), nameParts.get(1), scope);
}
// db.table.column
case 3: {
return bindSingleSlotByDb(nameParts.get(0), nameParts.get(1), nameParts.get(2), scope);
}
// catalog.db.table.column
case 4: {
return bindSingleSlotByCatalog(
nameParts.get(0), nameParts.get(1), nameParts.get(2), nameParts.get(3), scope);
}
default: {
throw new AnalysisException("Not supported name: " + StringUtils.join(nameParts, "."));
}
}
}
public static boolean compareDbName(String boundedDbName, String unBoundDbName) {
return unBoundDbName.equalsIgnoreCase(boundedDbName);
}
public static boolean sameTableName(String boundSlot, String unboundSlot) {
if (GlobalVariable.lowerCaseTableNames != 1) {
return boundSlot.equals(unboundSlot);
} else {
return boundSlot.equalsIgnoreCase(unboundSlot);
}
}
private void checkBoundLambda(Expression lambdaFunction, List<String> argumentNames) {
lambdaFunction.foreachUp(e -> {
if (e instanceof UnboundSlot) {
UnboundSlot unboundSlot = (UnboundSlot) e;
throw new AnalysisException("Unknown lambda slot '"
+ unboundSlot.getNameParts().get(unboundSlot.getNameParts().size() - 1)
+ " in lambda arguments" + argumentNames);
}
});
}
private UnboundFunction bindHighOrderFunction(UnboundFunction unboundFunction, ExpressionRewriteContext context) {
int childrenSize = unboundFunction.children().size();
List<Expression> subChildren = new ArrayList<>();
for (int i = 1; i < childrenSize; i++) {
subChildren.add(unboundFunction.child(i).accept(this, context));
}
// bindLambdaFunction
Lambda lambda = (Lambda) unboundFunction.children().get(0);
Expression lambdaFunction = lambda.getLambdaFunction();
List<ArrayItemReference> arrayItemReferences = lambda.makeArguments(subChildren);
// 1.bindSlot
List<Slot> boundedSlots = arrayItemReferences.stream()
.map(ArrayItemReference::toSlot)
.collect(ImmutableList.toImmutableList());
lambdaFunction = new SlotBinder(new Scope(boundedSlots), context.cascadesContext,
true, false).bind(lambdaFunction);
checkBoundLambda(lambdaFunction, lambda.getLambdaArgumentNames());
// 2.bindFunction
lambdaFunction = lambdaFunction.accept(this, context);
Lambda lambdaClosure = lambda.withLambdaFunctionArguments(lambdaFunction, arrayItemReferences);
// We don't add the ArrayExpression in high order function at all
return unboundFunction.withChildren(ImmutableList.<Expression>builder()
.add(lambdaClosure)
.build());
}
private boolean shouldBindSlotBy(int namePartSize, Slot boundSlot) {
if (boundSlot instanceof SlotReference
&& ((SlotReference) boundSlot).hasSubColPath()) {
// already bounded
return false;
}
if (namePartSize > boundSlot.getQualifier().size() + 1) {
return false;
}
return true;
}
private List<Slot> bindSingleSlotByName(String name, Scope scope) {
int namePartSize = 1;
Builder<Slot> usedSlots = ImmutableList.builderWithExpectedSize(1);
for (Slot boundSlot : scope.findSlotIgnoreCase(name)) {
if (!shouldBindSlotBy(namePartSize, boundSlot)) {
continue;
}
// set sql case as alias
usedSlots.add(boundSlot.withName(name));
}
return usedSlots.build();
}
private List<Slot> bindSingleSlotByTable(String table, String name, Scope scope) {
int namePartSize = 2;
Builder<Slot> usedSlots = ImmutableList.builderWithExpectedSize(1);
for (Slot boundSlot : scope.findSlotIgnoreCase(name)) {
if (!shouldBindSlotBy(namePartSize, boundSlot)) {
continue;
}
List<String> boundSlotQualifier = boundSlot.getQualifier();
String boundSlotTable = boundSlotQualifier.get(boundSlotQualifier.size() - 1);
if (!sameTableName(boundSlotTable, table)) {
continue;
}
// set sql case as alias
usedSlots.add(boundSlot.withName(name));
}
return usedSlots.build();
}
private List<Slot> bindSingleSlotByDb(String db, String table, String name, Scope scope) {
int namePartSize = 3;
Builder<Slot> usedSlots = ImmutableList.builderWithExpectedSize(1);
for (Slot boundSlot : scope.findSlotIgnoreCase(name)) {
if (!shouldBindSlotBy(namePartSize, boundSlot)) {
continue;
}
List<String> boundSlotQualifier = boundSlot.getQualifier();
String boundSlotDb = boundSlotQualifier.get(boundSlotQualifier.size() - 2);
String boundSlotTable = boundSlotQualifier.get(boundSlotQualifier.size() - 1);
if (!compareDbName(boundSlotDb, db) || !sameTableName(boundSlotTable, table)) {
continue;
}
// set sql case as alias
usedSlots.add(boundSlot.withName(name));
}
return usedSlots.build();
}
private List<Slot> bindSingleSlotByCatalog(String catalog, String db, String table, String name, Scope scope) {
int namePartSize = 4;
Builder<Slot> usedSlots = ImmutableList.builderWithExpectedSize(1);
for (Slot boundSlot : scope.findSlotIgnoreCase(name)) {
if (!shouldBindSlotBy(namePartSize, boundSlot)) {
continue;
}
List<String> boundSlotQualifier = boundSlot.getQualifier();
String boundSlotCatalog = boundSlotQualifier.get(boundSlotQualifier.size() - 3);
String boundSlotDb = boundSlotQualifier.get(boundSlotQualifier.size() - 2);
String boundSlotTable = boundSlotQualifier.get(boundSlotQualifier.size() - 1);
if (!boundSlotCatalog.equalsIgnoreCase(catalog)
|| !compareDbName(boundSlotDb, db)
|| !sameTableName(boundSlotTable, table)) {
continue;
}
// set sql case as alias
usedSlots.add(boundSlot.withName(name));
}
return usedSlots.build();
}
}

View File

@ -55,7 +55,7 @@ import java.util.stream.Collectors;
/**
* SlotBinder is used to bind slot
*/
public class SlotBinder extends SubExprAnalyzer {
public class SlotBinder extends SubExprAnalyzer<CascadesContext> {
/*
bounded={table.a, a}
unbound=a

View File

@ -20,7 +20,6 @@ package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
import org.apache.doris.nereids.trees.expressions.Exists;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InSubquery;
@ -49,8 +48,7 @@ import java.util.Optional;
/**
* Use the visitor to iterate sub expression.
*/
class SubExprAnalyzer extends DefaultExpressionRewriter<CascadesContext> {
class SubExprAnalyzer<T> extends DefaultExpressionRewriter<T> {
private final Scope scope;
private final CascadesContext cascadesContext;
@ -60,7 +58,7 @@ class SubExprAnalyzer extends DefaultExpressionRewriter<CascadesContext> {
}
@Override
public Expression visitNot(Not not, CascadesContext context) {
public Expression visitNot(Not not, T context) {
Expression child = not.child();
if (child instanceof Exists) {
return visitExistsSubquery(
@ -73,7 +71,7 @@ class SubExprAnalyzer extends DefaultExpressionRewriter<CascadesContext> {
}
@Override
public Expression visitExistsSubquery(Exists exists, CascadesContext context) {
public Expression visitExistsSubquery(Exists exists, T context) {
AnalyzedResult analyzedResult = analyzeSubquery(exists);
if (analyzedResult.rootIsLimitZero()) {
return BooleanLiteral.of(exists.isNot());
@ -87,7 +85,7 @@ class SubExprAnalyzer extends DefaultExpressionRewriter<CascadesContext> {
}
@Override
public Expression visitInSubquery(InSubquery expr, CascadesContext context) {
public Expression visitInSubquery(InSubquery expr, T context) {
AnalyzedResult analyzedResult = analyzeSubquery(expr);
checkOutputColumn(analyzedResult.getLogicalPlan());
@ -101,7 +99,7 @@ class SubExprAnalyzer extends DefaultExpressionRewriter<CascadesContext> {
}
@Override
public Expression visitScalarSubquery(ScalarSubquery scalar, CascadesContext context) {
public Expression visitScalarSubquery(ScalarSubquery scalar, T context) {
AnalyzedResult analyzedResult = analyzeSubquery(scalar);
checkOutputColumn(analyzedResult.getLogicalPlan());
@ -111,13 +109,6 @@ class SubExprAnalyzer extends DefaultExpressionRewriter<CascadesContext> {
return new ScalarSubquery(analyzedResult.getLogicalPlan(), analyzedResult.getCorrelatedSlots());
}
private boolean childrenAtLeastOneInOrExistsSub(BinaryOperator binaryOperator) {
return binaryOperator.left().anyMatch(InSubquery.class::isInstance)
|| binaryOperator.left().anyMatch(Exists.class::isInstance)
|| binaryOperator.right().anyMatch(InSubquery.class::isInstance)
|| binaryOperator.right().anyMatch(Exists.class::isInstance);
}
private void checkOutputColumn(LogicalPlan plan) {
if (plan.getOutput().size() != 1) {
throw new AnalysisException("Multiple columns returned by subquery are not yet supported. Found "

View File

@ -79,7 +79,7 @@ public class PushDownLimitDistinctThroughUnion implements RewriteRuleFactory {
.map(expr -> ExpressionUtils.replace(expr, replaceMap))
.collect(Collectors.toList());
List<NamedExpression> newOutputs = agg.getOutputs().stream()
.map(expr -> ExpressionUtils.replace(expr, replaceMap))
.map(expr -> ExpressionUtils.replaceNameExpression(expr, replaceMap))
.collect(Collectors.toList());
LogicalAggregate<Plan> newAgg = new LogicalAggregate<>(newGroupBy, newOutputs, child);

View File

@ -62,7 +62,7 @@ public class PushProjectIntoUnion extends OneRewriteRuleFactory {
if (old instanceof SlotReference) {
newProjections.add(replaceRootMap.get(old));
} else {
newProjections.add(ExpressionUtils.replace(old, replaceMap));
newProjections.add(ExpressionUtils.replaceNameExpression(old, replaceMap));
}
}
newConstExprs.add(newProjections.build());

View File

@ -65,7 +65,7 @@ public class PushProjectThroughUnion extends OneRewriteRuleFactory {
replaceMap.put(union.getOutput().get(j), union.getRegularChildOutput(i).get(j));
}
List<NamedExpression> childProjections = project.getProjects().stream()
.map(e -> (NamedExpression) ExpressionUtils.replace(e, replaceMap))
.map(e -> (NamedExpression) ExpressionUtils.replaceNameExpression(e, replaceMap))
.map(e -> {
if (e instanceof Alias) {
return new Alias(((Alias) e).child(), e.getName());

View File

@ -1572,7 +1572,7 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
LogicalProject<? extends Plan> project,
Map<Expression, Expression> projectMap) {
return project.getProjects().stream()
.map(expr -> (NamedExpression) ExpressionUtils.replace(expr, projectMap))
.map(expr -> (NamedExpression) ExpressionUtils.replaceNameExpression(expr, projectMap))
.collect(Collectors.toList());
}

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.trees;
import com.google.common.collect.ImmutableList;
import org.apache.doris.nereids.util.Utils;
import java.util.List;
@ -34,11 +34,15 @@ public abstract class AbstractTreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>>
// https://github.com/apache/doris/pull/9807#discussion_r884829067
protected AbstractTreeNode(NODE_TYPE... children) {
this.children = ImmutableList.copyOf(children);
// NOTE: ImmutableList.copyOf has additional clone of the list, so here we
// direct generate a ImmutableList
this.children = Utils.fastToImmutableList(children);
}
protected AbstractTreeNode(List<NODE_TYPE> children) {
this.children = ImmutableList.copyOf(children);
// NOTE: ImmutableList.copyOf has additional clone of the list, so here we
// direct generate a ImmutableList
this.children = Utils.fastToImmutableList(children);
}
@Override

View File

@ -17,6 +17,8 @@
package org.apache.doris.nereids.trees;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
@ -24,6 +26,7 @@ import com.google.common.collect.ImmutableSet;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.function.BiFunction;
@ -46,7 +49,7 @@ public interface TreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>> {
int arity();
default NODE_TYPE withChildren(NODE_TYPE... children) {
return withChildren(ImmutableList.copyOf(children));
return withChildren(Utils.fastToImmutableList(children));
}
NODE_TYPE withChildren(List<NODE_TYPE> children);
@ -175,6 +178,18 @@ public interface TreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>> {
}
}
/** foreachBreath */
default void foreachBreath(Predicate<TreeNode<NODE_TYPE>> func) {
LinkedList<TreeNode<NODE_TYPE>> queue = new LinkedList<>();
queue.add(this);
while (!queue.isEmpty()) {
TreeNode<NODE_TYPE> current = queue.pollFirst();
if (!func.test(current)) {
queue.addAll(current.children());
}
}
}
default void foreachUp(Consumer<TreeNode<NODE_TYPE>> func) {
for (NODE_TYPE child : children()) {
child.foreach(func);

View File

@ -19,18 +19,19 @@ package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import com.google.common.base.Preconditions;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Stream;
import java.util.function.Supplier;
/**
* The internal representation of
@ -43,17 +44,24 @@ public class CaseWhen extends Expression {
private final List<WhenClause> whenClauses;
private final Optional<Expression> defaultValue;
private Supplier<List<DataType>> dataTypesForCoercion;
public CaseWhen(List<WhenClause> whenClauses) {
super((List) whenClauses);
this.whenClauses = ImmutableList.copyOf(Objects.requireNonNull(whenClauses));
defaultValue = Optional.empty();
this.dataTypesForCoercion = computeDataTypesForCoercion();
}
/** CaseWhen */
public CaseWhen(List<WhenClause> whenClauses, Expression defaultValue) {
super(ImmutableList.<Expression>builder().addAll(whenClauses).add(defaultValue).build());
super(ImmutableList.<Expression>builderWithExpectedSize(whenClauses.size() + 1)
.addAll(whenClauses)
.add(defaultValue)
.build());
this.whenClauses = ImmutableList.copyOf(Objects.requireNonNull(whenClauses));
this.defaultValue = Optional.of(Objects.requireNonNull(defaultValue));
this.dataTypesForCoercion = computeDataTypesForCoercion();
}
public List<WhenClause> getWhenClauses() {
@ -64,10 +72,9 @@ public class CaseWhen extends Expression {
return defaultValue;
}
/** dataTypesForCoercion */
public List<DataType> dataTypesForCoercion() {
return Stream.concat(whenClauses.stream(), defaultValue.map(Stream::of).orElseGet(Stream::empty))
.map(ExpressionTrait::getDataType)
.collect(ImmutableList.toImmutableList());
return this.dataTypesForCoercion.get();
}
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
@ -136,4 +143,16 @@ public class CaseWhen extends Expression {
}
return new CaseWhen(whenClauseList, defaultValue);
}
private Supplier<List<DataType>> computeDataTypesForCoercion() {
return Suppliers.memoize(() -> {
Builder<DataType> dataTypes = ImmutableList.builderWithExpectedSize(
whenClauses.size() + (defaultValue.isPresent() ? 1 : 0));
for (WhenClause whenClause : whenClauses) {
dataTypes.add(whenClause.getDataType());
}
defaultValue.ifPresent(expression -> dataTypes.add(expression.getDataType()));
return dataTypes.build();
});
}
}

View File

@ -39,7 +39,6 @@ import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.MapType;
import org.apache.doris.nereids.types.StructField;
import org.apache.doris.nereids.types.StructType;
import org.apache.doris.nereids.types.coercion.AnyDataType;
import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
@ -47,7 +46,6 @@ import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import org.apache.commons.lang3.StringUtils;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
@ -67,36 +65,36 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
protected Expression(Expression... children) {
super(children);
depth = Arrays.stream(children)
.mapToInt(e -> e.depth)
.max().orElse(0) + 1;
width = Arrays.stream(children)
.mapToInt(e -> e.width)
.sum() + (children.length == 0 ? 1 : 0);
int maxChildDepth = 0;
int sumChildWidth = 0;
for (int i = 0; i < children.length; ++i) {
Expression child = children[i];
maxChildDepth = Math.max(child.depth, maxChildDepth);
sumChildWidth += child.width;
}
this.depth = maxChildDepth + 1;
this.width = sumChildWidth + ((children.length == 0) ? 1 : 0);
checkLimit();
this.inferred = false;
}
protected Expression(List<Expression> children) {
super(children);
depth = children.stream()
.mapToInt(e -> e.depth)
.max().orElse(0) + 1;
width = children.stream()
.mapToInt(e -> e.width)
.sum() + (children.isEmpty() ? 1 : 0);
checkLimit();
this.inferred = false;
this(children, false);
}
protected Expression(List<Expression> children, boolean inferred) {
super(children);
depth = children.stream()
.mapToInt(e -> e.depth)
.max().orElse(0) + 1;
width = children.stream()
.mapToInt(e -> e.width)
.sum() + (children.isEmpty() ? 1 : 0);
int maxChildDepth = 0;
int sumChildWidth = 0;
for (int i = 0; i < children.size(); ++i) {
Expression child = children.get(i);
maxChildDepth = Math.max(child.depth, maxChildDepth);
sumChildWidth += child.width;
}
this.depth = maxChildDepth + 1;
this.width = sumChildWidth + ((children.isEmpty()) ? 1 : 0);
checkLimit();
this.inferred = inferred;
}
@ -130,8 +128,8 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
*/
public TypeCheckResult checkInputDataTypes() {
// check all of its children recursively.
for (Expression expression : this.children) {
TypeCheckResult childResult = expression.checkInputDataTypes();
for (Expression child : this.children) {
TypeCheckResult childResult = child.checkInputDataTypes();
if (childResult.failed()) {
return childResult;
}
@ -180,21 +178,20 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
}
private boolean checkPrimitiveInputDataTypesWithExpectType(DataType input, DataType expected) {
// These type will throw exception when invoke toCatalogDataType()
if (expected instanceof AnyDataType) {
return expected.acceptsType(input);
// support fast check the case: input=TinyIntType, expected=NumericType, for example: `1 + 1`.
// if no this check, there will have an exception when invoke NumericType.toCatalogDataType,
// when there has lots of expression, the exception become the bottleneck, because an exception
// need to record the whole StackFrame.
if (expected.acceptsType(input)) {
return true;
}
// TODO: complete the cast logic like FunctionCallExpr.analyzeImpl
boolean legacyCastCompatible = false;
try {
legacyCastCompatible = input.toCatalogDataType().matchesType(expected.toCatalogDataType());
return input.toCatalogDataType().matchesType(expected.toCatalogDataType());
} catch (Throwable t) {
// ignore.
}
if (!legacyCastCompatible && !expected.acceptsType(input)) {
return false;
}
return true;
}
private TypeCheckResult checkInputDataTypesWithExpectTypes(

View File

@ -22,6 +22,7 @@ import org.apache.doris.catalog.TableIf;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.trees.plans.algebra.Relation;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.qe.ConnectContext;
import com.google.common.base.Preconditions;
@ -101,7 +102,8 @@ public class SlotReference extends Slot {
this.exprId = exprId;
this.name = name;
this.dataType = dataType;
this.qualifier = ImmutableList.copyOf(Objects.requireNonNull(qualifier, "qualifier can not be null"));
this.qualifier = Utils.fastToImmutableList(
Objects.requireNonNull(qualifier, "qualifier can not be null"));
this.nullable = nullable;
this.table = table;
this.column = column;

View File

@ -38,9 +38,9 @@ import org.apache.doris.nereids.util.ResponsibilityChain;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import java.math.BigDecimal;
import java.util.List;
@ -233,7 +233,7 @@ public class ComputeSignatureHelper {
public static FunctionSignature implementAnyDataTypeWithOutIndex(
FunctionSignature signature, List<Expression> arguments) {
// collect all any data type with index
List<DataType> newArgTypes = Lists.newArrayList();
List<DataType> newArgTypes = Lists.newArrayListWithCapacity(arguments.size());
for (int i = 0; i < arguments.size(); i++) {
DataType sigType;
if (i >= signature.argumentsTypes.size()) {
@ -267,10 +267,19 @@ public class ComputeSignatureHelper {
collectAnyDataType(sigType, expressionType, indexToArgumentTypes);
}
// if all any data type's expression is NULL, we should use follow to any data type to do type coercion
Set<Integer> allNullTypeIndex = indexToArgumentTypes.entrySet().stream()
.filter(entry -> entry.getValue().stream().allMatch(NullType.class::isInstance))
.map(Entry::getKey)
.collect(ImmutableSet.toImmutableSet());
Set<Integer> allNullTypeIndex = Sets.newHashSetWithExpectedSize(indexToArgumentTypes.size());
for (Entry<Integer, List<DataType>> entry : indexToArgumentTypes.entrySet()) {
boolean allIsNullType = true;
for (DataType dataType : entry.getValue()) {
if (!(dataType instanceof NullType)) {
allIsNullType = false;
break;
}
}
if (allIsNullType) {
allNullTypeIndex.add(entry.getKey());
}
}
if (!allNullTypeIndex.isEmpty()) {
for (int i = 0; i < arguments.size(); i++) {
DataType sigType;
@ -297,7 +306,7 @@ public class ComputeSignatureHelper {
}
// replace any data type and follow to any data type with real data type
List<DataType> newArgTypes = Lists.newArrayList();
List<DataType> newArgTypes = Lists.newArrayListWithCapacity(signature.argumentsTypes.size());
for (DataType sigType : signature.argumentsTypes) {
newArgTypes.add(replaceAnyDataType(sigType, indexToCommonTypes));
}
@ -324,10 +333,18 @@ public class ComputeSignatureHelper {
if (computeSignature instanceof ComputePrecision) {
return ((ComputePrecision) computeSignature).computePrecision(signature);
}
if (signature.argumentsTypes.stream().anyMatch(TypeCoercionUtils::hasDateTimeV2Type)) {
boolean hasDateTimeV2Type = false;
boolean hasDecimalV3Type = false;
for (DataType argumentsType : signature.argumentsTypes) {
hasDateTimeV2Type |= TypeCoercionUtils.hasDateTimeV2Type(argumentsType);
hasDecimalV3Type |= TypeCoercionUtils.hasDecimalV3Type(argumentsType);
}
if (hasDateTimeV2Type) {
signature = defaultDateTimeV2PrecisionPromotion(signature, arguments);
}
if (signature.argumentsTypes.stream().anyMatch(TypeCoercionUtils::hasDecimalV3Type)) {
if (hasDecimalV3Type) {
// do decimal v3 precision
signature = defaultDecimalV3PrecisionPromotion(signature, arguments);
}

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.analyzer.ComplexDataType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.NullType;
import org.apache.doris.nereids.types.coercion.AnyDataType;
@ -49,6 +50,9 @@ public interface ExplicitlyCastableSignature extends ComputeSignature {
if (realType instanceof NullType) {
return true;
}
if (signatureType instanceof ComplexDataType && !(realType instanceof ComplexDataType)) {
return false;
}
try {
// TODO: copy canCastTo method to DataType
return Type.canCastTo(realType.toCatalogDataType(), signatureType.toCatalogDataType());

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.analyzer.ComplexDataType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.AnyDataType;
import org.apache.doris.nereids.types.coercion.FollowToAnyDataType;
@ -45,6 +46,9 @@ public interface IdenticalSignature extends ComputeSignature {
if (signatureType instanceof AnyDataType || signatureType instanceof FollowToAnyDataType) {
return false;
}
if (signatureType instanceof ComplexDataType && !(realType instanceof ComplexDataType)) {
return false;
}
return realType.toCatalogDataType().matchesType(signatureType.toCatalogDataType());
} catch (Throwable t) {
// the signatureType maybe DataType and can not cast to catalog data type.

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.analyzer.ComplexDataType;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.NullType;
@ -51,6 +52,9 @@ public interface ImplicitlyCastableSignature extends ComputeSignature {
if (realType instanceof NullType) {
return true;
}
if (signatureType instanceof ComplexDataType && !(realType instanceof ComplexDataType)) {
return false;
}
try {
// TODO: copy isImplicitlyCastable method to DataType
// TODO: resolve AnyDataType invoke toCatalogDataType
@ -68,8 +72,10 @@ public interface ImplicitlyCastableSignature extends ComputeSignature {
}
try {
List<DataType> allPromotions = realType.getAllPromotions();
if (allPromotions.stream().anyMatch(promotion -> isImplicitlyCastable(signatureType, promotion))) {
return true;
for (DataType promotion : allPromotions) {
if (isImplicitlyCastable(signatureType, promotion)) {
return true;
}
}
} catch (Throwable t) {
// the signatureType maybe DataType and can not cast to catalog data type.

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.analyzer.ComplexDataType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.NullType;
import org.apache.doris.nereids.types.coercion.AnyDataType;
@ -48,6 +49,9 @@ public interface NullOrIdenticalSignature extends ComputeSignature {
if (signatureType instanceof AnyDataType) {
return false;
}
if (signatureType instanceof ComplexDataType && !(realType instanceof ComplexDataType)) {
return false;
}
return realType.toCatalogDataType().matchesType(signatureType.toCatalogDataType());
} catch (Throwable t) {
// the signatureType maybe DataType and can not cast to catalog data type.

View File

@ -19,8 +19,8 @@ package org.apache.doris.nereids.trees.expressions.visitor;
import org.apache.doris.nereids.trees.expressions.Expression;
import java.util.ArrayList;
import java.util.List;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
/**
* Default implementation for expression rewriting, delegating to child expressions and rewrite current root
@ -30,20 +30,42 @@ public abstract class DefaultExpressionRewriter<C> extends ExpressionVisitor<Exp
@Override
public Expression visit(Expression expr, C context) {
return rewrite(this, expr, context);
return rewriteChildren(this, expr, context);
}
/** rewrite */
public static final <C> Expression rewrite(ExpressionVisitor<Expression, C> rewriter, Expression expr, C context) {
List<Expression> newChildren = new ArrayList<>(expr.arity());
boolean hasNewChildren = false;
for (Expression child : expr.children()) {
Expression newChild = child.accept(rewriter, context);
if (newChild != child) {
hasNewChildren = true;
/** rewriteChildren */
public static final <C> Expression rewriteChildren(
ExpressionVisitor<Expression, C> rewriter, Expression expr, C context) {
switch (expr.arity()) {
case 1: {
Expression originChild = expr.child(0);
Expression newChild = originChild.accept(rewriter, context);
return (originChild != newChild) ? expr.withChildren(ImmutableList.of(newChild)) : expr;
}
case 2: {
Expression originLeft = expr.child(0);
Expression newLeft = originLeft.accept(rewriter, context);
Expression originRight = expr.child(1);
Expression newRight = originRight.accept(rewriter, context);
return (originLeft != newLeft || originRight != newRight)
? expr.withChildren(ImmutableList.of(newLeft, newRight))
: expr;
}
case 0: {
return expr;
}
default: {
boolean hasNewChildren = false;
Builder<Expression> newChildren = ImmutableList.builderWithExpectedSize(expr.arity());
for (Expression child : expr.children()) {
Expression newChild = child.accept(rewriter, context);
if (newChild != child) {
hasNewChildren = true;
}
newChildren.add(newChild);
}
return hasNewChildren ? expr.withChildren(newChildren.build()) : expr;
}
newChildren.add(newChild);
}
return hasNewChildren ? expr.withChildren(newChildren) : expr;
}
}

View File

@ -63,9 +63,9 @@ public class LogicalGenerate<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD
Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties, CHILD_TYPE child) {
super(PlanType.LOGICAL_GENERATE, groupExpression, logicalProperties, child);
this.generators = ImmutableList.copyOf(generators);
this.generatorOutput = ImmutableList.copyOf(generatorOutput);
this.expandColumnAlias = ImmutableList.copyOf(expandColumnAlias);
this.generators = Utils.fastToImmutableList(generators);
this.generatorOutput = Utils.fastToImmutableList(generatorOutput);
this.expandColumnAlias = Utils.fastToImmutableList(expandColumnAlias);
}
public List<Function> getGenerators() {

View File

@ -158,9 +158,9 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
// Just use in withXXX method. Don't need check/copyOf()
super(PlanType.LOGICAL_JOIN, groupExpression, logicalProperties, children);
this.joinType = Objects.requireNonNull(joinType, "joinType can not be null");
this.hashJoinConjuncts = ImmutableList.copyOf(hashJoinConjuncts);
this.otherJoinConjuncts = ImmutableList.copyOf(otherJoinConjuncts);
this.markJoinConjuncts = ImmutableList.copyOf(markJoinConjuncts);
this.hashJoinConjuncts = Utils.fastToImmutableList(hashJoinConjuncts);
this.otherJoinConjuncts = Utils.fastToImmutableList(otherJoinConjuncts);
this.markJoinConjuncts = Utils.fastToImmutableList(markJoinConjuncts);
this.hint = Objects.requireNonNull(hint, "hint can not be null");
if (joinReorderContext != null) {
this.joinReorderContext.copyFrom(joinReorderContext);

View File

@ -37,6 +37,7 @@ import org.apache.doris.nereids.util.Utils;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.commons.lang3.tuple.Pair;
@ -331,14 +332,20 @@ public class LogicalOlapScan extends LogicalCatalogRelation implements OlapScan
if (selectedIndexId != ((OlapTable) table).getBaseIndexId()) {
return getOutputByIndex(selectedIndexId);
}
return table.getBaseSchema(true).stream().map(col -> {
if (cacheSlotWithSlotName.containsKey(Pair.of(selectedIndexId, col.getName()))) {
return cacheSlotWithSlotName.get(Pair.of(selectedIndexId, col.getName()));
List<Column> baseSchema = table.getBaseSchema(true);
Builder<Slot> slots = ImmutableList.builder();
for (Column col : baseSchema) {
Pair<Long, String> key = Pair.of(selectedIndexId, col.getName());
Slot slot = cacheSlotWithSlotName.get(key);
if (slot != null) {
slots.add(slot);
} else {
slot = SlotReference.fromColumn(table, col, qualified(), this);
cacheSlotWithSlotName.put(key, slot);
slots.add(slot);
}
Slot slot = SlotReference.fromColumn(table, col, qualified(), this);
cacheSlotWithSlotName.put(Pair.of(selectedIndexId, col.getName()), slot);
return slot;
}).collect(ImmutableList.toImmutableList());
}
return slots.build();
}
@Override

View File

@ -38,6 +38,7 @@ import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import org.json.JSONObject;
@ -94,7 +95,7 @@ public class LogicalProject<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_
this.projects = projects.isEmpty()
? ImmutableList.of(ExpressionUtils.selectMinimumColumn(child.get(0).getOutput()))
: projects;
this.excepts = ImmutableList.copyOf(excepts);
this.excepts = Utils.fastToImmutableList(excepts);
this.isDistinct = isDistinct;
this.canEliminate = canEliminate;
}
@ -119,9 +120,11 @@ public class LogicalProject<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_
@Override
public List<Slot> computeOutput() {
return projects.stream()
.map(NamedExpression::toSlot)
.collect(ImmutableList.toImmutableList());
Builder<Slot> slots = ImmutableList.builderWithExpectedSize(projects.size());
for (NamedExpression project : projects) {
slots.add(project.toSlot());
}
return slots.build();
}
@Override
@ -170,7 +173,7 @@ public class LogicalProject<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_
@Override
public LogicalProject<Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1);
return new LogicalProject<>(projects, excepts, isDistinct, canEliminate, ImmutableList.copyOf(children));
return new LogicalProject<>(projects, excepts, isDistinct, canEliminate, Utils.fastToImmutableList(children));
}
@Override

View File

@ -41,9 +41,7 @@ import org.apache.doris.nereids.util.TypeCoercionUtils;
import org.apache.doris.qe.SessionVariable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
@ -115,8 +113,9 @@ public abstract class LogicalSetOperation extends AbstractLogicalPlan implements
* Generate new output for SetOperation.
*/
public List<NamedExpression> buildNewOutputs() {
ImmutableList.Builder<NamedExpression> newOutputs = new Builder<>();
for (Slot slot : resetNullableForLeftOutputs()) {
List<Slot> slots = resetNullableForLeftOutputs();
ImmutableList.Builder<NamedExpression> newOutputs = ImmutableList.builderWithExpectedSize(slots.size());
for (Slot slot : slots) {
newOutputs.add(new SlotReference(slot.toSql(), slot.getDataType(), slot.nullable()));
}
return newOutputs.build();
@ -124,22 +123,28 @@ public abstract class LogicalSetOperation extends AbstractLogicalPlan implements
// If the right child is nullable, need to ensure that the left child is also nullable
private List<Slot> resetNullableForLeftOutputs() {
List<Slot> resetNullableForLeftOutputs = new ArrayList<>();
for (int i = 0; i < child(1).getOutput().size(); ++i) {
int rightChildOutputSize = child(1).getOutput().size();
ImmutableList.Builder<Slot> resetNullableForLeftOutputs
= ImmutableList.builderWithExpectedSize(rightChildOutputSize);
for (int i = 0; i < rightChildOutputSize; ++i) {
if (child(1).getOutput().get(i).nullable() && !child(0).getOutput().get(i).nullable()) {
resetNullableForLeftOutputs.add(child(0).getOutput().get(i).withNullable(true));
} else {
resetNullableForLeftOutputs.add(child(0).getOutput().get(i));
}
}
return ImmutableList.copyOf(resetNullableForLeftOutputs);
return resetNullableForLeftOutputs.build();
}
private List<List<NamedExpression>> castCommonDataTypeOutputs() {
List<NamedExpression> newLeftOutputs = new ArrayList<>();
List<NamedExpression> newRightOutputs = new ArrayList<>();
int childOutputSize = child(0).getOutput().size();
ImmutableList.Builder<NamedExpression> newLeftOutputs = ImmutableList.builderWithExpectedSize(
childOutputSize);
ImmutableList.Builder<NamedExpression> newRightOutputs = ImmutableList.builderWithExpectedSize(
childOutputSize
);
// Ensure that the output types of the left and right children are consistent and expand upward.
for (int i = 0; i < child(0).getOutput().size(); ++i) {
for (int i = 0; i < childOutputSize; ++i) {
Slot left = child(0).getOutput().get(i);
Slot right = child(1).getOutput().get(i);
DataType compatibleType = getAssignmentCompatibleType(left.getDataType(), right.getDataType());
@ -155,10 +160,7 @@ public abstract class LogicalSetOperation extends AbstractLogicalPlan implements
newRightOutputs.add((NamedExpression) newRight);
}
List<List<NamedExpression>> resultExpressions = new ArrayList<>();
resultExpressions.add(newLeftOutputs);
resultExpressions.add(newRightOutputs);
return ImmutableList.copyOf(resultExpressions);
return ImmutableList.of(newLeftOutputs.build(), newRightOutputs.build());
}
@Override

View File

@ -70,7 +70,7 @@ public class LogicalUnion extends LogicalSetOperation implements Union, OutputPr
List<List<NamedExpression>> constantExprsList, boolean hasPushedFilter, List<Plan> children) {
super(PlanType.LOGICAL_UNION, qualifier, outputs, childrenOutputs, children);
this.hasPushedFilter = hasPushedFilter;
this.constantExprsList = ImmutableList.copyOf(
this.constantExprsList = Utils.fastToImmutableList(
Objects.requireNonNull(constantExprsList, "constantExprsList should not be null"));
}
@ -81,7 +81,7 @@ public class LogicalUnion extends LogicalSetOperation implements Union, OutputPr
super(PlanType.LOGICAL_UNION, qualifier, outputs, childrenOutputs,
groupExpression, logicalProperties, children);
this.hasPushedFilter = hasPushedFilter;
this.constantExprsList = ImmutableList.copyOf(
this.constantExprsList = Utils.fastToImmutableList(
Objects.requireNonNull(constantExprsList, "constantExprsList should not be null"));
}

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.plans.visitor;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
@ -28,49 +29,64 @@ import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
import java.util.Set;
/**
* Infer output column name when it refers an expression and not has an alias manually.
*/
public class InferPlanOutputAlias extends DefaultPlanVisitor<Void, ImmutableMultimap<ExprId, Integer>> {
public class InferPlanOutputAlias {
private final List<Slot> currentOutputs;
private final List<NamedExpression> finalOutputs;
private final Set<Integer> shouldProcessOutputIndex;
/** InferPlanOutputAlias */
public InferPlanOutputAlias(List<Slot> currentOutputs) {
this.currentOutputs = currentOutputs;
this.finalOutputs = new ArrayList<>(currentOutputs);
}
@Override
public Void visit(Plan plan, ImmutableMultimap<ExprId, Integer> currentExprIdAndIndexMap) {
List<Alias> aliasProjects = plan.getExpressions().stream()
.filter(expression -> expression instanceof Alias)
.map(Alias.class::cast)
.collect(Collectors.toList());
ImmutableSet<ExprId> currentOutputExprIdSet = currentExprIdAndIndexMap.keySet();
for (Alias projectItem : aliasProjects) {
ExprId exprId = projectItem.getExprId();
// Infer name when alias child is expression and alias's name is from child
if (currentOutputExprIdSet.contains(projectItem.getExprId())
&& projectItem.isNameFromChild()) {
String inferredAliasName = projectItem.child().getExpressionName();
ImmutableCollection<Integer> outPutExprIndexes = currentExprIdAndIndexMap.get(exprId);
// replace output name by inferred name
outPutExprIndexes.forEach(index -> {
Slot slot = currentOutputs.get(index);
finalOutputs.set(index, slot.withName("__" + inferredAliasName + "_" + index));
});
}
this.shouldProcessOutputIndex = new HashSet<>();
for (int i = 0; i < currentOutputs.size(); i++) {
shouldProcessOutputIndex.add(i);
}
return super.visit(plan, currentExprIdAndIndexMap);
}
public List<NamedExpression> getOutputs() {
/** infer */
public List<NamedExpression> infer(Plan plan, ImmutableMultimap<ExprId, Integer> currentExprIdAndIndexMap) {
ImmutableSet<ExprId> currentOutputExprIdSet = currentExprIdAndIndexMap.keySet();
// Breath First Search
plan.foreachBreath(childPlan -> {
if (shouldProcessOutputIndex.isEmpty()) {
return true;
}
for (Expression expression : ((Plan) childPlan).getExpressions()) {
if (!(expression instanceof Alias)) {
continue;
}
Alias projectItem = (Alias) expression;
ExprId exprId = projectItem.getExprId();
// Infer name when alias child is expression and alias's name is from child
if (currentOutputExprIdSet.contains(projectItem.getExprId())
&& projectItem.isNameFromChild()) {
String inferredAliasName = projectItem.child().getExpressionName();
ImmutableCollection<Integer> outputExprIndexes = currentExprIdAndIndexMap.get(exprId);
// replace output name by inferred name
for (Integer index : outputExprIndexes) {
Slot slot = currentOutputs.get(index);
finalOutputs.set(index, slot.withName("__" + inferredAliasName + "_" + index));
shouldProcessOutputIndex.remove(index);
if (shouldProcessOutputIndex.isEmpty()) {
// replace finished
return true;
}
}
}
}
// continue replace
return false;
});
return finalOutputs;
}
}

View File

@ -18,13 +18,14 @@
package org.apache.doris.nereids.types;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.analyzer.ComplexDataType;
import java.util.Objects;
/**
* Array type in Nereids.
*/
public class ArrayType extends DataType {
public class ArrayType extends DataType implements ComplexDataType {
public static final ArrayType SYSTEM_DEFAULT = new ArrayType(NullType.INSTANCE, true);

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.types;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.analyzer.ComplexDataType;
import org.apache.doris.nereids.annotation.Developing;
import java.util.Objects;
@ -26,7 +27,7 @@ import java.util.Objects;
* Struct type in Nereids.
*/
@Developing
public class MapType extends DataType {
public class MapType extends DataType implements ComplexDataType {
public static final MapType SYSTEM_DEFAULT = new MapType();

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.types;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.analyzer.ComplexDataType;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.exceptions.AnalysisException;
@ -36,7 +37,7 @@ import java.util.stream.Collectors;
* Struct type in Nereids.
*/
@Developing
public class StructType extends DataType {
public class StructType extends DataType implements ComplexDataType {
public static final StructType SYSTEM_DEFAULT = new StructType();

View File

@ -59,6 +59,7 @@ import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
@ -67,6 +68,7 @@ import com.google.common.collect.Sets;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@ -188,9 +190,9 @@ public class ExpressionUtils {
*/
public static Expression combine(Class<? extends Expression> type, Collection<Expression> expressions) {
/*
* (AB) (CD) E ((AB)(CD)) E (((AB)(CD))E)
* ▲ ▲ ▲ ▲ ▲
* │ │ │ │ │
* (AB) (CD) E ((AB)(CD)) E (((AB)(CD))E)
* ▲ ▲ ▲ ▲ ▲
* │ │ │ │ │
* A B C D E ──► A B C D E ──► (AB) (CD) E ──► ((AB)(CD)) E ──► (((AB)(CD))E)
*/
Preconditions.checkArgument(type == And.class || type == Or.class);
@ -223,8 +225,7 @@ public class ExpressionUtils {
}
/**
* Replace the slot in expressions with the lineage identifier from
* specifiedbaseTable sets or target table types
* Replace the slot in expressions with the lineage identifier from specifiedbaseTable sets or target table types
* example as following:
* select a + 10 as a1, d from (
* select b - 5 as a, d from table
@ -241,9 +242,9 @@ public class ExpressionUtils {
}
ExpressionLineageReplacer.ExpressionReplaceContext replaceContext =
new ExpressionLineageReplacer.ExpressionReplaceContext(
expressions.stream().map(Expression.class::cast).collect(Collectors.toList()),
targetTypes,
tableIdentifiers);
expressions.stream().map(Expression.class::cast).collect(Collectors.toList()),
targetTypes,
tableIdentifiers);
plan.accept(ExpressionLineageReplacer.INSTANCE, replaceContext);
// Replace expressions by expression map
@ -277,10 +278,8 @@ public class ExpressionUtils {
}
/**
* Check whether the input expression is a
* {@link org.apache.doris.nereids.trees.expressions.Slot}
* or at least one {@link Cast} on a
* {@link org.apache.doris.nereids.trees.expressions.Slot}
* Check whether the input expression is a {@link org.apache.doris.nereids.trees.expressions.Slot}
* or at least one {@link Cast} on a {@link org.apache.doris.nereids.trees.expressions.Slot}
* <p>
* for example:
* - SlotReference to a column:
@ -290,8 +289,7 @@ public class ExpressionUtils {
* cast(cast(int_col as long) as string)
*
* @param expr input expression
* @return Return Optional[ExprId] of underlying slot reference if input
* expression is a slot or cast on slot.
* @return Return Optional[ExprId] of underlying slot reference if input expression is a slot or cast on slot.
* Otherwise, return empty optional result.
*/
public static Optional<ExprId> isSlotOrCastOnSlot(Expression expr) {
@ -299,10 +297,8 @@ public class ExpressionUtils {
}
/**
* Check whether the input expression is a
* {@link org.apache.doris.nereids.trees.expressions.Slot}
* or at least one {@link Cast} on a
* {@link org.apache.doris.nereids.trees.expressions.Slot}
* Check whether the input expression is a {@link org.apache.doris.nereids.trees.expressions.Slot}
* or at least one {@link Cast} on a {@link org.apache.doris.nereids.trees.expressions.Slot}
*/
public static Optional<Slot> extractSlotOrCastOnSlot(Expression expr) {
while (expr instanceof Cast) {
@ -317,23 +313,35 @@ public class ExpressionUtils {
}
/**
* Generate replaceMap Slot -> Expression from NamedExpression[Expression as
* name]
* Generate replaceMap Slot -> Expression from NamedExpression[Expression as name]
*/
public static Map<Slot, Expression> generateReplaceMap(List<NamedExpression> namedExpressions) {
return namedExpressions
.stream()
.filter(Alias.class::isInstance)
.collect(
Collectors.toMap(
NamedExpression::toSlot,
// Avoid cast to alias, retrieving the first child expression.
alias -> alias.child(0)));
ImmutableMap.Builder<Slot, Expression> replaceMap = ImmutableMap.builderWithExpectedSize(
namedExpressions.size() * 2);
for (NamedExpression namedExpression : namedExpressions) {
if (namedExpression instanceof Alias) {
// Avoid cast to alias, retrieving the first child expression.
replaceMap.put(namedExpression.toSlot(), namedExpression.child(0));
}
}
return replaceMap.build();
}
/**
* Replace expression node in the expression tree by `replaceMap` in top-down
* manner.
* replace NameExpression.
*/
public static NamedExpression replaceNameExpression(NamedExpression expr,
Map<? extends Expression, ? extends Expression> replaceMap) {
Expression newExpr = replace(expr, replaceMap);
if (newExpr instanceof NamedExpression) {
return (NamedExpression) newExpr;
} else {
return new Alias(expr.getExprId(), newExpr, expr.getName());
}
}
/**
* Replace expression node in the expression tree by `replaceMap` in top-down manner.
* For example.
* <pre>
* input expression: a > 1
@ -344,20 +352,10 @@ public class ExpressionUtils {
* </pre>
*/
public static Expression replace(Expression expr, Map<? extends Expression, ? extends Expression> replaceMap) {
return expr.accept(ExpressionReplacer.INSTANCE, replaceMap);
}
/**
* replace NameExpression.
*/
public static NamedExpression replace(NamedExpression expr,
Map<? extends Expression, ? extends Expression> replaceMap) {
Expression newExpr = expr.accept(ExpressionReplacer.INSTANCE, replaceMap);
if (newExpr instanceof NamedExpression) {
return (NamedExpression) newExpr;
} else {
return new Alias(expr.getExprId(), newExpr, expr.getName());
}
return expr.rewriteDownShortCircuit(e -> {
Expression replacedExpr = replaceMap.get(e);
return replacedExpr == null ? e : replacedExpr;
});
}
public static List<Expression> replace(List<Expression> exprs,
@ -375,21 +373,20 @@ public class ExpressionUtils {
}
/**
* Replace expression node in the expression tree by `replaceMap` in top-down
* manner.
* Replace expression node in the expression tree by `replaceMap` in top-down manner.
*/
public static List<NamedExpression> replaceNamedExpressions(List<NamedExpression> namedExpressions,
Map<? extends Expression, ? extends Expression> replaceMap) {
return namedExpressions.stream()
.map(namedExpression -> {
NamedExpression newExpr = replace(namedExpression, replaceMap);
if (newExpr.getExprId().equals(namedExpression.getExprId())) {
return newExpr;
} else {
return new Alias(namedExpression.getExprId(), newExpr, namedExpression.getName());
}
})
.collect(ImmutableList.toImmutableList());
Builder<NamedExpression> replaceExprs = ImmutableList.builderWithExpectedSize(namedExpressions.size());
for (NamedExpression namedExpression : namedExpressions) {
NamedExpression newExpr = replaceNameExpression(namedExpression, replaceMap);
if (newExpr.getExprId().equals(namedExpression.getExprId())) {
replaceExprs.add(newExpr);
} else {
replaceExprs.add(new Alias(namedExpression.getExprId(), newExpr, namedExpression.getName()));
}
}
return replaceExprs.build();
}
public static <E extends Expression> List<E> rewriteDownShortCircuit(
@ -489,19 +486,19 @@ public class ExpressionUtils {
public static boolean canInferNotNullForMarkSlot(Expression predicate) {
/*
* assume predicate is from LogicalFilter
* the idea is replacing each mark join slot with null and false literal then
* run FoldConstant rule
* the idea is replacing each mark join slot with null and false literal then run FoldConstant rule
* if the evaluate result are:
* 1. all true
* 2. all null and false (in logicalFilter, we discard both null and false
* values)
* 2. all null and false (in logicalFilter, we discard both null and false values)
* the mark slot can be non-nullable boolean
* and in semi join, we can safely change the mark conjunct to hash conjunct
*/
ImmutableList<Literal> literals = ImmutableList.of(new NullLiteral(BooleanType.INSTANCE), BooleanLiteral.FALSE);
List<MarkJoinSlotReference> markJoinSlotReferenceList = ((Set<MarkJoinSlotReference>) predicate
.collect(MarkJoinSlotReference.class::isInstance)).stream()
.collect(Collectors.toList());
ImmutableList<Literal> literals =
ImmutableList.of(new NullLiteral(BooleanType.INSTANCE), BooleanLiteral.FALSE);
List<MarkJoinSlotReference> markJoinSlotReferenceList =
((Set<MarkJoinSlotReference>) predicate
.collect(MarkJoinSlotReference.class::isInstance)).stream()
.collect(Collectors.toList());
int markSlotSize = markJoinSlotReferenceList.size();
int maxMarkSlotCount = 4;
// if the conjunct has mark slot, and maximum 4 mark slots(for performance)
@ -510,9 +507,9 @@ public class ExpressionUtils {
boolean meetTrue = false;
boolean meetNullOrFalse = false;
/*
* markSlotSize = 1 -> loopCount = 2 ---- 0, 1
* markSlotSize = 2 -> loopCount = 4 ---- 00, 01, 10, 11
* markSlotSize = 3 -> loopCount = 8 ---- 000, 001, 010, 011, 100, 101, 110, 111
* markSlotSize = 1 -> loopCount = 2 ---- 0, 1
* markSlotSize = 2 -> loopCount = 4 ---- 00, 01, 10, 11
* markSlotSize = 3 -> loopCount = 8 ---- 000, 001, 010, 011, 100, 101, 110, 111
* markSlotSize = 4 -> loopCount = 16 ---- 0000, 0001, ... 1111
*/
int loopCount = 2 << markSlotSize;
@ -583,8 +580,7 @@ public class ExpressionUtils {
}
/**
* infer notNulls slot from predicate but these slots must be in the given
* slots.
* infer notNulls slot from predicate but these slots must be in the given slots.
*/
public static Set<Expression> inferNotNull(Set<Expression> predicates, Set<Slot> slots,
CascadesContext cascadesContext) {
@ -614,7 +610,7 @@ public class ExpressionUtils {
return anyMatch(expressions, type::isInstance);
}
public static <E> Set<E> collect(List<? extends Expression> expressions,
public static <E> Set<E> collect(Collection<? extends Expression> expressions,
Predicate<TreeNode<Expression>> predicate) {
return expressions.stream()
.flatMap(expr -> expr.<Set<E>>collect(predicate).stream())
@ -654,7 +650,7 @@ public class ExpressionUtils {
.collect(Collectors.toSet());
}
public static <E> List<E> collectAll(List<? extends Expression> expressions,
public static <E> List<E> collectAll(Collection<? extends Expression> expressions,
Predicate<TreeNode<Expression>> predicate) {
return expressions.stream()
.flatMap(expr -> expr.<Set<E>>collect(predicate).stream())
@ -764,18 +760,18 @@ public class ExpressionUtils {
*/
public static boolean checkSlotConstant(Slot slot, Set<Expression> predicates) {
return predicates.stream().anyMatch(predicate -> {
if (predicate instanceof EqualTo) {
EqualTo equalTo = (EqualTo) predicate;
return (equalTo.left() instanceof Literal && equalTo.right().equals(slot))
|| (equalTo.right() instanceof Literal && equalTo.left().equals(slot));
}
return false;
});
if (predicate instanceof EqualTo) {
EqualTo equalTo = (EqualTo) predicate;
return (equalTo.left() instanceof Literal && equalTo.right().equals(slot))
|| (equalTo.right() instanceof Literal && equalTo.left().equals(slot));
}
return false;
}
);
}
/**
* Check the expression is inferred or not, if inferred return true, nor return
* false
* Check the expression is inferred or not, if inferred return true, nor return false
*/
public static boolean isInferred(Expression expression) {
return expression.accept(new DefaultExpressionVisitor<Boolean, Void>() {
@ -794,4 +790,17 @@ public class ExpressionUtils {
}
}, null);
}
/** distinctSlotByName */
public static List<Slot> distinctSlotByName(List<Slot> slots) {
Set<String> existSlotNames = new HashSet<>(slots.size() * 2);
Builder<Slot> distinctSlots = ImmutableList.builderWithExpectedSize(slots.size());
for (Slot slot : slots) {
String name = slot.getName();
if (existSlotNames.add(name)) {
distinctSlots.add(slot);
}
}
return distinctSlots.build();
}
}

View File

@ -45,6 +45,7 @@ import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
@ -290,8 +291,11 @@ public class JoinUtils {
}
private static List<Slot> applyNullable(List<Slot> slots, boolean nullable) {
return slots.stream().map(o -> o.withNullable(nullable))
.collect(ImmutableList.toImmutableList());
Builder<Slot> newSlots = ImmutableList.builderWithExpectedSize(slots.size());
for (Slot slot : slots) {
newSlots.add(slot.withNullable(nullable));
}
return newSlots.build();
}
private static Map<Slot, Slot> mapPrimaryToForeign(ImmutableEqualSet<Slot> equivalenceSet,

View File

@ -34,6 +34,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
@ -125,6 +126,23 @@ public class PlanUtils {
return resultSet;
}
/** fastGetChildrenOutput */
public static List<Slot> fastGetChildrenOutputs(List<Plan> children) {
int outputNum = 0;
// child.output is cached by AbstractPlan.logicalProperties,
// we can compute output num without the overhead of re-compute output
for (Plan child : children) {
List<Slot> output = child.getOutput();
outputNum += output.size();
}
// generate output list only copy once and without resize the list
Builder<Slot> output = ImmutableList.builderWithExpectedSize(outputNum);
for (Plan child : children) {
output.addAll(child.getOutput());
}
return output.build();
}
/**
* collect non_window_agg_func
*/

View File

@ -21,6 +21,7 @@ import org.apache.doris.analysis.FunctionCallExpr;
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Config;
import org.apache.doris.nereids.analyzer.ComplexDataType;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Add;
@ -119,9 +120,7 @@ import java.math.BigInteger;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* Utils for type coercion.
@ -595,7 +594,7 @@ public class TypeCoercionUtils {
private static List<Optional<DataType>> getInputImplicitCastTypes(
List<Expression> inputs, List<DataType> expectedTypes) {
Builder<Optional<DataType>> implicitCastTypes = ImmutableList.builder();
Builder<Optional<DataType>> implicitCastTypes = ImmutableList.builderWithExpectedSize(inputs.size());
for (int i = 0; i < inputs.size(); i++) {
DataType argType = inputs.get(i).getDataType();
DataType expectedType = expectedTypes.get(i);
@ -765,15 +764,16 @@ public class TypeCoercionUtils {
binaryArithmetic = TypeCoercionUtils.processCharacterLiteralInBinaryOperator(binaryArithmetic, left, right);
// check string literal can cast to double
binaryArithmetic.children().stream().filter(e -> e instanceof StringLikeLiteral)
.forEach(expr -> {
try {
new BigDecimal(((StringLikeLiteral) expr).getStringValue());
} catch (NumberFormatException e) {
throw new IllegalStateException(String.format(
"string literal %s cannot be cast to double", expr.toSql()));
}
});
for (Expression expr : binaryArithmetic.children()) {
if (expr instanceof StringLikeLiteral) {
try {
new BigDecimal(((StringLikeLiteral) expr).getStringValue());
} catch (NumberFormatException e) {
throw new IllegalStateException(String.format(
"string literal %s cannot be cast to double", expr.toSql()));
}
}
}
// 1. choose default numeric type for left and right
DataType t1 = TypeCoercionUtils.getNumResultType(left.getDataType());
@ -1106,14 +1106,13 @@ public class TypeCoercionUtils {
private static Optional<DataType> findWiderTypeForTwoForComparison(
DataType left, DataType right, boolean intStringToString) {
// TODO: need to rethink how to handle char and varchar to return char or varchar as much as possible.
return Stream
.<Supplier<Optional<DataType>>>of(
() -> findCommonComplexTypeForComparison(left, right, intStringToString),
() -> findCommonPrimitiveTypeForComparison(left, right, intStringToString))
.map(Supplier::get)
.filter(Optional::isPresent)
.map(Optional::get)
.findFirst();
if (left instanceof ComplexDataType) {
Optional<DataType> commonType = findCommonComplexTypeForComparison(left, right, intStringToString);
if (commonType.isPresent()) {
return commonType;
}
}
return findCommonPrimitiveTypeForComparison(left, right, intStringToString);
}
/**
@ -1310,20 +1309,22 @@ public class TypeCoercionUtils {
Map<Boolean, List<DataType>> partitioned = dataTypes.stream()
.collect(Collectors.partitioningBy(TypeCoercionUtils::hasCharacterType));
List<DataType> needTypeCoercion = Lists.newArrayList(Sets.newHashSet(partitioned.get(true)));
if (needTypeCoercion.size() > 1 || !partitioned.get(false).isEmpty()) {
needTypeCoercion = needTypeCoercion.stream()
.map(TypeCoercionUtils::replaceCharacterToString)
.collect(Collectors.toList());
List<DataType> nonCharTypes = partitioned.get(false);
if (needTypeCoercion.size() > 1 || !nonCharTypes.isEmpty()) {
needTypeCoercion = Utils.fastMapList(
needTypeCoercion, nonCharTypes.size(), TypeCoercionUtils::replaceCharacterToString);
}
needTypeCoercion.addAll(partitioned.get(false));
return needTypeCoercion.stream().map(Optional::of).reduce(Optional.of(NullType.INSTANCE),
(r, c) -> {
if (r.isPresent() && c.isPresent()) {
return findWiderTypeForTwoForCaseWhen(r.get(), c.get());
} else {
return Optional.empty();
}
});
needTypeCoercion.addAll(nonCharTypes);
DataType commonType = NullType.INSTANCE;
for (DataType dataType : needTypeCoercion) {
Optional<DataType> newCommonType = findWiderTypeForTwoForCaseWhen(commonType, dataType);
if (!newCommonType.isPresent()) {
return Optional.empty();
}
commonType = newCommonType.get();
}
return Optional.of(commonType);
}
/**
@ -1332,14 +1333,11 @@ public class TypeCoercionUtils {
@Developing
private static Optional<DataType> findWiderTypeForTwoForCaseWhen(DataType left, DataType right) {
// TODO: need to rethink how to handle char and varchar to return char or varchar as much as possible.
return Stream
.<Supplier<Optional<DataType>>>of(
() -> findCommonComplexTypeForCaseWhen(left, right),
() -> findCommonPrimitiveTypeForCaseWhen(left, right))
.map(Supplier::get)
.filter(Optional::isPresent)
.map(Optional::get)
.findFirst();
Optional<DataType> commonType = findCommonComplexTypeForCaseWhen(left, right);
if (commonType.isPresent()) {
return commonType;
}
return findCommonPrimitiveTypeForCaseWhen(left, right);
}
/**
@ -1585,7 +1583,7 @@ public class TypeCoercionUtils {
*/
public static BoundFunction fillJsonValueModifyTypeArgument(BoundFunction function) {
List<Expression> arguments = function.getArguments();
List<Expression> newArguments = Lists.newArrayList();
List<Expression> newArguments = Lists.newArrayListWithCapacity(arguments.size() + 1);
StringBuilder jsonTypeStr = new StringBuilder();
for (int i = 0; i < arguments.size(); i++) {
Expression argument = arguments.get(i);

View File

@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import com.google.common.base.CaseFormat;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Lists;
import org.apache.commons.lang3.StringUtils;
@ -34,6 +35,8 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@ -294,4 +297,70 @@ public class Utils {
}
return false;
}
public static <I, O> List<O> fastMapList(List<I> list, int additionSize, Function<I, O> transformer) {
List<O> newList = Lists.newArrayListWithCapacity(list.size() + additionSize);
for (I input : list) {
newList.add(transformer.apply(input));
}
return newList;
}
/** fastToImmutableList */
public static <E> ImmutableList<E> fastToImmutableList(E[] array) {
switch (array.length) {
case 0:
return ImmutableList.of();
case 1:
return ImmutableList.of(array[0]);
default:
// NOTE: ImmutableList.copyOf(array) has additional clone of the array, so here we
// direct generate a ImmutableList
Builder<E> copyChildren = ImmutableList.builderWithExpectedSize(array.length);
for (E child : array) {
copyChildren.add(child);
}
return copyChildren.build();
}
}
/** fastToImmutableList */
public static <E> ImmutableList<E> fastToImmutableList(List<? extends E> originList) {
if (originList instanceof ImmutableList) {
return (ImmutableList<E>) originList;
}
switch (originList.size()) {
case 0: return ImmutableList.of();
case 1: return ImmutableList.of(originList.get(0));
default: {
// NOTE: ImmutableList.copyOf(list) has additional clone of the list, so here we
// direct generate a ImmutableList
Builder<E> copyChildren = ImmutableList.builderWithExpectedSize(originList.size());
copyChildren.addAll(originList);
return copyChildren.build();
}
}
}
/** reverseImmutableList */
public static <E> ImmutableList<E> reverseImmutableList(List<? extends E> list) {
Builder<E> reverseList = ImmutableList.builderWithExpectedSize(list.size());
for (int i = list.size() - 1; i >= 0; i--) {
reverseList.add(list.get(i));
}
return reverseList.build();
}
/** filterImmutableList */
public static <E> ImmutableList<E> filterImmutableList(List<? extends E> list, Predicate<E> filter) {
Builder<E> newList = ImmutableList.builderWithExpectedSize(list.size());
for (int i = 0; i < list.size(); i++) {
E item = list.get(i);
if (filter.test(item)) {
newList.add(item);
}
}
return newList.build();
}
}