[enhancement](nereids) make filter node and join node work in Nereids (#10605)

enhancement
- add functions `finalizeForNereids` and `finalizeImplForNereids` to the stale expression classes to generate some attributes used in the BE.
- remove unnecessary parameter `Analyzer` in function `getBuiltinFunction`
- swap the join condition's operands if its left-hand expression is related to the right table
- change join physical implementation to broadcast hash join 
- add push predicate rule into planner

fix
- swap the join children's visit order to ensure the last fragment is the root
- avoid visiting the join's left child twice

known issues
- expression computation will generate a wrong answer when the expression includes arithmetic with two literal children.
This commit is contained in:
morrySnow
2022-07-05 18:23:00 +08:00
committed by GitHub
parent 3b0ddd7ae0
commit 589ab06b5c
21 changed files with 292 additions and 114 deletions

View File

@ -321,8 +321,7 @@ public class ArithmeticExpr extends Expr {
} else {
type = t;
}
fn = getBuiltinFunction(
analyzer, op.getName(), collectChildReturnTypes(), Function.CompareMode.IS_SUPERTYPE_OF);
fn = getBuiltinFunction(op.getName(), collectChildReturnTypes(), Function.CompareMode.IS_SUPERTYPE_OF);
if (fn == null) {
Preconditions.checkState(false, String.format("No match for op with operand types", toSql()));
}
@ -406,8 +405,7 @@ public class ArithmeticExpr extends Expr {
"Unknown arithmetic operation " + op.toString() + " in: " + this.toSql());
break;
}
fn = getBuiltinFunction(analyzer, op.name, collectChildReturnTypes(),
Function.CompareMode.IS_IDENTICAL);
fn = getBuiltinFunction(op.name, collectChildReturnTypes(), Function.CompareMode.IS_IDENTICAL);
if (fn == null) {
Preconditions.checkState(false, String.format(
"No match for vec function '%s' with operand types %s and %s", toSql(), t1, t2));
@ -420,8 +418,7 @@ public class ArithmeticExpr extends Expr {
if (getChild(0).getType().getPrimitiveType() != PrimitiveType.BIGINT) {
castChild(type, 0);
}
fn = getBuiltinFunction(
analyzer, op.getName(), collectChildReturnTypes(), Function.CompareMode.IS_SUPERTYPE_OF);
fn = getBuiltinFunction(op.getName(), collectChildReturnTypes(), Function.CompareMode.IS_SUPERTYPE_OF);
if (fn == null) {
Preconditions.checkState(false, String.format("No match for op with operand types", toSql()));
}
@ -467,8 +464,7 @@ public class ArithmeticExpr extends Expr {
break;
}
type = castBinaryOp(commonType);
fn = getBuiltinFunction(analyzer, fnName, collectChildReturnTypes(),
Function.CompareMode.IS_IDENTICAL);
fn = getBuiltinFunction(fnName, collectChildReturnTypes(), Function.CompareMode.IS_IDENTICAL);
if (fn == null) {
Preconditions.checkState(false, String.format(
"No match for '%s' with operand types %s and %s", toSql(), t1, t2));
@ -506,4 +502,16 @@ public class ArithmeticExpr extends Expr {
return 31 * super.hashCode() + Objects.hashCode(op);
}
@Override
public void finalizeImplForNereids() throws AnalysisException {
if (op == Operator.BITNOT) {
fn = getBuiltinFunction(op.getName(), collectChildReturnTypes(), Function.CompareMode.IS_SUPERTYPE_OF);
} else {
fn = getBuiltinFunction(op.name, collectChildReturnTypes(), Function.CompareMode.IS_IDENTICAL);
}
if (fn == null) {
Preconditions.checkState(false, String.format("No match for op with operand types. %s", toSql()));
}
type = fn.getReturnType();
}
}

View File

@ -124,4 +124,9 @@ public class BetweenPredicate extends Predicate {
public int hashCode() {
return 31 * super.hashCode() + Boolean.hashCode(isNotBetween);
}
@Override
public void finalizeImplForNereids() throws AnalysisException {
throw new AnalysisException("analyze between predicate for Nereids do not implementation.");
}
}

View File

@ -278,8 +278,8 @@ public class BinaryPredicate extends Predicate implements Writable {
//OpcodeRegistry.BuiltinFunction match = OpcodeRegistry.instance().getFunctionInfo(
// op.toFilterFunctionOp(), true, true, cmpType, cmpType);
try {
match = getBuiltinFunction(analyzer, op.name, collectChildReturnTypes(),
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
match = getBuiltinFunction(op.name, collectChildReturnTypes(),
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
} catch (AnalysisException e) {
Preconditions.checkState(false);
}
@ -408,8 +408,7 @@ public class BinaryPredicate extends Predicate implements Writable {
this.opcode = op.getOpcode();
String opName = op.getName();
fn = getBuiltinFunction(analyzer, opName, collectChildReturnTypes(),
Function.CompareMode.IS_SUPERTYPE_OF);
fn = getBuiltinFunction(opName, collectChildReturnTypes(), Function.CompareMode.IS_SUPERTYPE_OF);
if (fn == null) {
Preconditions.checkState(false, String.format(
"No match for '%s' with operand types %s and %s", toSql()));
@ -697,4 +696,10 @@ public class BinaryPredicate extends Predicate implements Writable {
}
return hasNullableChild();
}
@Override
public void finalizeImplForNereids() throws AnalysisException {
super.finalizeImplForNereids();
fn = getBuiltinFunction(op.name, collectChildReturnTypes(), Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
}
}

View File

@ -256,4 +256,9 @@ public class CompoundPredicate extends Predicate {
public boolean isNullable() {
return hasNullableChild();
}
@Override
public void finalizeImplForNereids() throws AnalysisException {
}
}

View File

@ -40,4 +40,9 @@ public class DefaultValueExpr extends Expr {
public Expr clone() {
return null;
}
@Override
public void finalizeImplForNereids() throws AnalysisException {
}
}

View File

@ -1617,8 +1617,7 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
* Looks up in the catalog the builtin for 'name' and 'argTypes'.
* Returns null if the function is not found.
*/
protected Function getBuiltinFunction(
Analyzer analyzer, String name, Type[] argTypes, Function.CompareMode mode)
protected Function getBuiltinFunction(String name, Type[] argTypes, Function.CompareMode mode)
throws AnalysisException {
FunctionName fnName = new FunctionName(name);
Function searchDesc = new Function(fnName, Arrays.asList(argTypes), Type.INVALID, false,
@ -1633,8 +1632,7 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
return f;
}
protected Function getTableFunction(String name, Type[] argTypes,
Function.CompareMode mode) {
protected Function getTableFunction(String name, Type[] argTypes, Function.CompareMode mode) {
FunctionName fnName = new FunctionName(name);
Function searchDesc = new Function(fnName, Arrays.asList(argTypes), Type.INVALID, false,
VectorizedUtil.isVectorized());
@ -1975,4 +1973,19 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
}
return true;
}
public final void finalizeForNereids() throws AnalysisException {
if (isAnalyzed()) {
return;
}
for (Expr child : children) {
child.finalizeForNereids();
}
finalizeImplForNereids();
analysisDone();
}
public void finalizeImplForNereids() throws AnalysisException {
throw new AnalysisException("analyze for Nereids do not implementation.");
}
}

View File

@ -793,15 +793,13 @@ public class FunctionCallExpr extends Expr {
* @throws AnalysisException
*/
public void analyzeImplForDefaultValue() throws AnalysisException {
fn = getBuiltinFunction(null, fnName.getFunction(), new Type[0],
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
fn = getBuiltinFunction(fnName.getFunction(), new Type[0], Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
type = fn.getReturnType();
for (int i = 0; i < children.size(); ++i) {
if (getChild(i).getType().isNull()) {
uncheckedCastChild(Type.BOOLEAN, i);
}
}
return;
}
@Override
@ -825,8 +823,7 @@ public class FunctionCallExpr extends Expr {
// There is no version of COUNT() that takes more than 1 argument but after
// the equal, we only need count(*).
// TODO: fix how we equal count distinct.
fn = getBuiltinFunction(analyzer, fnName.getFunction(), new Type[0],
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
fn = getBuiltinFunction(fnName.getFunction(), new Type[0], Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
type = fn.getReturnType();
// Make sure BE doesn't see any TYPE_NULL exprs
@ -854,7 +851,7 @@ public class FunctionCallExpr extends Expr {
if (!VectorizedUtil.isVectorized()) {
type = getChild(0).type.getMaxResolutionType();
}
fn = getBuiltinFunction(analyzer, fnName.getFunction(), new Type[]{type},
fn = getBuiltinFunction(fnName.getFunction(), new Type[]{type},
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
} else if (fnName.getFunction().equalsIgnoreCase("count_distinct")) {
Type compatibleType = this.children.get(0).getType();
@ -867,7 +864,7 @@ public class FunctionCallExpr extends Expr {
}
}
fn = getBuiltinFunction(analyzer, fnName.getFunction(), new Type[]{compatibleType},
fn = getBuiltinFunction(fnName.getFunction(), new Type[]{compatibleType},
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
} else if (fnName.getFunction().equalsIgnoreCase(FunctionSet.WINDOW_FUNNEL)) {
if (fnParams.exprs() == null || fnParams.exprs().size() < 4) {
@ -895,14 +892,14 @@ public class FunctionCallExpr extends Expr {
}
childTypes[i] = children.get(i).type;
}
fn = getBuiltinFunction(analyzer, fnName.getFunction(), childTypes,
fn = getBuiltinFunction(fnName.getFunction(), childTypes,
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
} else if (fnName.getFunction().equalsIgnoreCase("if")) {
Type[] childTypes = collectChildReturnTypes();
Type assignmentCompatibleType = ScalarType.getAssignmentCompatibleType(childTypes[1], childTypes[2], true);
childTypes[1] = assignmentCompatibleType;
childTypes[2] = assignmentCompatibleType;
fn = getBuiltinFunction(analyzer, fnName.getFunction(), childTypes,
fn = getBuiltinFunction(fnName.getFunction(), childTypes,
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
} else {
// now first find table function in table function sets
@ -917,7 +914,7 @@ public class FunctionCallExpr extends Expr {
// now first find function in built-in functions
if (Strings.isNullOrEmpty(fnName.getDb())) {
Type[] childTypes = collectChildReturnTypes();
fn = getBuiltinFunction(analyzer, fnName.getFunction(), childTypes,
fn = getBuiltinFunction(fnName.getFunction(), childTypes,
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
}
@ -1257,4 +1254,9 @@ public class FunctionCallExpr extends Expr {
}
return result.toString();
}
@Override
public void finalizeImplForNereids() throws AnalysisException {
super.finalizeImplForNereids();
}
}

View File

@ -73,8 +73,7 @@ public class GroupingFunctionCallExpr extends FunctionCallExpr {
}
Type[] childTypes = new Type[1];
childTypes[0] = Type.BIGINT;
fn = getBuiltinFunction(analyzer, getFnName().getFunction(), childTypes,
Function.CompareMode.IS_IDENTICAL);
fn = getBuiltinFunction(getFnName().getFunction(), childTypes, Function.CompareMode.IS_IDENTICAL);
this.type = fn.getReturnType();
}

View File

@ -199,7 +199,7 @@ public class InPredicate extends Predicate {
// argTypes, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
opcode = isNotIn ? TExprOpcode.FILTER_NOT_IN : TExprOpcode.FILTER_IN;
} else {
fn = getBuiltinFunction(analyzer, isNotIn ? NOT_IN_ITERATE : IN_ITERATE,
fn = getBuiltinFunction(isNotIn ? NOT_IN_ITERATE : IN_ITERATE,
argTypes, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
opcode = isNotIn ? TExprOpcode.FILTER_NEW_NOT_IN : TExprOpcode.FILTER_NEW_IN;
}
@ -319,4 +319,29 @@ public class InPredicate extends Predicate {
public boolean isNullable() {
return hasNullableChild();
}
@Override
public void finalizeImplForNereids() throws AnalysisException {
super.finalizeImplForNereids();
boolean allConstant = true;
for (int i = 1; i < children.size(); ++i) {
if (!children.get(i).isConstant()) {
allConstant = false;
break;
}
}
// Only lookup fn_ if all subqueries have been rewritten. If the second child is a
// subquery, it will have type ArrayType, which cannot be resolved to a builtin
// function and will fail analysis.
Type[] argTypes = {getChild(0).type, getChild(1).type};
if (allConstant) {
// fn = getBuiltinFunction(analyzer, isNotIn ? NOT_IN_SET_LOOKUP : IN_SET_LOOKUP,
// argTypes, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
opcode = isNotIn ? TExprOpcode.FILTER_NOT_IN : TExprOpcode.FILTER_IN;
} else {
fn = getBuiltinFunction(isNotIn ? NOT_IN_ITERATE : IN_ITERATE,
argTypes, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
opcode = isNotIn ? TExprOpcode.FILTER_NEW_NOT_IN : TExprOpcode.FILTER_NEW_IN;
}
}
}

View File

@ -113,11 +113,9 @@ public class IsNullPredicate extends Predicate {
public void analyzeImpl(Analyzer analyzer) throws AnalysisException {
super.analyzeImpl(analyzer);
if (isNotNull) {
fn = getBuiltinFunction(
analyzer, IS_NOT_NULL, collectChildReturnTypes(), Function.CompareMode.IS_INDISTINGUISHABLE);
fn = getBuiltinFunction(IS_NOT_NULL, collectChildReturnTypes(), Function.CompareMode.IS_INDISTINGUISHABLE);
} else {
fn = getBuiltinFunction(
analyzer, IS_NULL, collectChildReturnTypes(), Function.CompareMode.IS_INDISTINGUISHABLE);
fn = getBuiltinFunction(IS_NULL, collectChildReturnTypes(), Function.CompareMode.IS_INDISTINGUISHABLE);
}
Preconditions.checkState(fn != null, "tupleisNull fn == NULL");
@ -156,4 +154,15 @@ public class IsNullPredicate extends Predicate {
}
return childValue instanceof NullLiteral ? new BoolLiteral(!isNotNull) : new BoolLiteral(isNotNull);
}
@Override
public void finalizeImplForNereids() throws AnalysisException {
super.finalizeImplForNereids();
if (isNotNull) {
fn = getBuiltinFunction(IS_NOT_NULL, collectChildReturnTypes(), Function.CompareMode.IS_INDISTINGUISHABLE);
} else {
fn = getBuiltinFunction(IS_NULL, collectChildReturnTypes(), Function.CompareMode.IS_INDISTINGUISHABLE);
}
Preconditions.checkState(fn != null, "tupleisNull fn == NULL");
}
}

View File

@ -134,8 +134,8 @@ public class LikePredicate extends Predicate {
uncheckedCastChild(Type.VARCHAR, 0);
}
fn = getBuiltinFunction(analyzer, op.toString(),
collectChildReturnTypes(), Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
fn = getBuiltinFunction(op.toString(), collectChildReturnTypes(),
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
if (!getChild(1).getType().isNull() && getChild(1).isLiteral() && (op == Operator.REGEXP)) {
// let's make sure the pattern works
@ -154,4 +154,10 @@ public class LikePredicate extends Predicate {
return 31 * super.hashCode() + Objects.hashCode(op);
}
@Override
public void finalizeImplForNereids() throws AnalysisException {
super.finalizeImplForNereids();
fn = getBuiltinFunction(op.toString(), collectChildReturnTypes(),
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
}
}

View File

@ -247,4 +247,9 @@ public abstract class LiteralExpr extends Expr implements Comparable<LiteralExpr
public boolean isNullable() {
return this instanceof NullLiteral;
}
@Override
public void finalizeImplForNereids() throws AnalysisException {
}
}

View File

@ -156,4 +156,9 @@ public abstract class Predicate extends Expr {
public Pair<SlotId, SlotId> getEqSlots() {
return null;
}
@Override
public void finalizeImplForNereids() throws AnalysisException {
type = Type.BOOLEAN;
}
}

View File

@ -441,4 +441,9 @@ public class SlotRef extends Expr {
Preconditions.checkNotNull(desc);
return desc.getIsNullable();
}
@Override
public void finalizeImplForNereids() throws AnalysisException {
}
}

View File

@ -213,8 +213,8 @@ public class TimestampArithmeticExpr extends Expr {
(op == ArithmeticExpr.Operator.ADD) ? "ADD" : "SUB");
}
fn = getBuiltinFunction(analyzer, funcOpName.toLowerCase(),
collectChildReturnTypes(), Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
fn = getBuiltinFunction(funcOpName.toLowerCase(), collectChildReturnTypes(),
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
LOG.debug("fn is {} name is {}", fn, funcOpName);
}

View File

@ -198,4 +198,9 @@ public class TupleIsNullPredicate extends Predicate {
public boolean isNullable() {
return false;
}
@Override
public void finalizeImplForNereids() throws AnalysisException {
super.finalizeImplForNereids();
}
}

View File

@ -18,19 +18,27 @@
package org.apache.doris.nereids;
import org.apache.doris.analysis.DescriptorTable;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.SlotDescriptor;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.StatementBase;
import org.apache.doris.analysis.TupleDescriptor;
import org.apache.doris.analysis.TupleId;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.Id;
import org.apache.doris.common.UserException;
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator;
import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
import org.apache.doris.nereids.jobs.AnalyzeRulesJob;
import org.apache.doris.nereids.jobs.OptimizeRulesJob;
import org.apache.doris.nereids.jobs.PredicatePushDownRulesJob;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
@ -40,10 +48,12 @@ import org.apache.doris.planner.ScanNode;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
@ -66,26 +76,42 @@ public class NereidsPlanner extends Planner {
if (!(queryStmt instanceof LogicalPlanAdapter)) {
throw new RuntimeException("Wrong type of queryStmt, expected: <? extends LogicalPlanAdapter>");
}
LogicalPlanAdapter logicalPlanAdapter = (LogicalPlanAdapter) queryStmt;
PhysicalPlan physicalPlan = plan(logicalPlanAdapter.getLogicalPlan(), new PhysicalProperties(), ctx);
PhysicalPlanTranslator physicalPlanTranslator = new PhysicalPlanTranslator();
PlanTranslatorContext planTranslatorContext = new PlanTranslatorContext();
physicalPlanTranslator.translatePlan(physicalPlan, planTranslatorContext);
scanNodeList = planTranslatorContext.getScanNodeList();
descTable = planTranslatorContext.getDescTable();
fragments = new ArrayList<>(planTranslatorContext.getPlanFragmentList());
PlanFragment root = fragments.get(fragments.size() - 1);
for (PlanFragment fragment : fragments) {
fragment.finalize(queryStmt);
}
root.resetOutputExprs(descTable.getTupleDesc(root.getPlanRoot().getTupleIds().get(0)));
Collections.reverse(fragments);
PlanFragment root = fragments.get(0);
// compute output exprs
Map<Integer, Expr> outputCandidates = Maps.newHashMap();
List<Expr> outputExprs = Lists.newArrayList();
for (TupleId tupleId : root.getPlanRoot().getTupleIds()) {
TupleDescriptor tupleDescriptor = descTable.getTupleDesc(tupleId);
for (SlotDescriptor slotDescriptor : tupleDescriptor.getSlots()) {
SlotRef slotRef = new SlotRef(slotDescriptor);
outputCandidates.put(slotDescriptor.getId().asInt(), slotRef);
}
}
physicalPlan.getOutput().stream().map(Slot::getExprId)
.map(Id::asInt).forEach(i -> outputExprs.add(outputCandidates.get(i)));
root.setOutputExprs(outputExprs);
root.getPlanRoot().convertToVectoriezd();
scanNodeList = planTranslatorContext.getScanNodeList();
logicalPlanAdapter.setResultExprs(root.getOutputExprs());
logicalPlanAdapter.setResultExprs(outputExprs);
ArrayList<String> columnLabelList = physicalPlan.getOutput().stream()
.map(NamedExpression::getName).collect(Collectors.toCollection(ArrayList::new));
logicalPlanAdapter.setColLabels(columnLabelList);
Collections.reverse(fragments);
}
/**
@ -118,6 +144,9 @@ public class NereidsPlanner extends Planner {
AnalyzeRulesJob analyzeRulesJob = new AnalyzeRulesJob(plannerContext);
analyzeRulesJob.execute();
PredicatePushDownRulesJob predicatePushDownRulesJob = new PredicatePushDownRulesJob(plannerContext);
predicatePushDownRulesJob.execute();
OptimizeRulesJob optimizeRulesJob = new OptimizeRulesJob(plannerContext);
optimizeRulesJob.execute();

View File

@ -28,6 +28,7 @@ import org.apache.doris.analysis.IntLiteral;
import org.apache.doris.analysis.NullLiteral;
import org.apache.doris.analysis.StringLiteral;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.NodeType;
import org.apache.doris.nereids.trees.expressions.Arithmetic;
import org.apache.doris.nereids.trees.expressions.Between;
@ -62,8 +63,23 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
public static ExpressionTranslator INSTANCE = new ExpressionTranslator();
public static Expr translate(Expression expression, PlanTranslatorContext planContext) {
return expression.accept(INSTANCE, planContext);
/**
* The entry function of ExpressionTranslator, call {@link Expr#finalizeForNereids()} to generate
* some attributes using in BE.
*
* @param expression nereids expression
* @param context translator context
* @return stale planner's expr
*/
public static Expr translate(Expression expression, PlanTranslatorContext context) {
Expr staleExpr = expression.accept(INSTANCE, context);
try {
staleExpr.finalizeForNereids();
} catch (org.apache.doris.common.AnalysisException e) {
throw new AnalysisException(
"Translate Nereids expression to stale expression failed. " + e.getMessage(), e);
}
return staleExpr;
}
@Override
@ -159,7 +175,7 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
@Override
public Expr visitCompoundPredicate(CompoundPredicate compoundPredicate, PlanTranslatorContext context) {
NodeType nodeType = compoundPredicate.getType();
org.apache.doris.analysis.CompoundPredicate.Operator staleOp = null;
org.apache.doris.analysis.CompoundPredicate.Operator staleOp;
switch (nodeType) {
case OR:
staleOp = org.apache.doris.analysis.CompoundPredicate.Operator.OR;
@ -171,7 +187,7 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
staleOp = org.apache.doris.analysis.CompoundPredicate.Operator.NOT;
break;
default:
throw new RuntimeException(String.format("Unknown node type: %s", nodeType.name()));
throw new AnalysisException(String.format("Unknown node type: %s", nodeType.name()));
}
return new org.apache.doris.analysis.CompoundPredicate(staleOp,
compoundPredicate.child(0).accept(this, context),

View File

@ -21,13 +21,11 @@ import org.apache.doris.analysis.AggregateInfo;
import org.apache.doris.analysis.BaseTableRef;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.FunctionCallExpr;
import org.apache.doris.analysis.SlotDescriptor;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.SortInfo;
import org.apache.doris.analysis.TableName;
import org.apache.doris.analysis.TableRef;
import org.apache.doris.analysis.TupleDescriptor;
import org.apache.doris.analysis.TupleId;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.Table;
import org.apache.doris.nereids.exceptions.AnalysisException;
@ -40,23 +38,27 @@ import org.apache.doris.nereids.operators.plans.physical.PhysicalOlapScan;
import org.apache.doris.nereids.operators.plans.physical.PhysicalOperator;
import org.apache.doris.nereids.operators.plans.physical.PhysicalProject;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanOperatorVisitor;
import org.apache.doris.nereids.trees.plans.physical.PhysicalBinaryPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalLeafPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalUnaryPlan;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.planner.AggregationNode;
import org.apache.doris.planner.CrossJoinNode;
import org.apache.doris.planner.DataPartition;
import org.apache.doris.planner.ExchangeNode;
import org.apache.doris.planner.HashJoinNode;
import org.apache.doris.planner.HashJoinNode.DistributionMode;
import org.apache.doris.planner.OlapScanNode;
import org.apache.doris.planner.PlanFragment;
import org.apache.doris.planner.PlanNode;
@ -64,8 +66,6 @@ import org.apache.doris.planner.SortNode;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.HashSet;
@ -82,7 +82,23 @@ import java.util.stream.Collectors;
* </STRONG>
*/
public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, PlanTranslatorContext> {
private static final Logger LOG = LogManager.getLogger(PhysicalPlanTranslator.class);
/**
* The left and right child of origin predicates need to be swap sometimes.
* Case A:
* select * from t1 join t2 on t2.id=t1.id
* The left plan node is t1 and the right plan node is t2.
* The left child of origin predicate is t2.id and the right child of origin predicate is t1.id.
* In this situation, the children of predicate need to be swap => t1.id=t2.id.
*/
private static Expression swapEqualToForChildrenOrder(EqualTo<?, ?> equalTo, List<Slot> leftOutput) {
Set<ExprId> leftSlots = SlotExtractor.extractSlot(equalTo.left()).stream()
.map(NamedExpression::getExprId).collect(Collectors.toSet());
if (leftOutput.stream().map(NamedExpression::getExprId).collect(Collectors.toSet()).containsAll(leftSlots)) {
return equalTo;
} else {
return new EqualTo<>(equalTo.right(), equalTo.left());
}
}
public void translatePlan(PhysicalPlan physicalPlan, PlanTranslatorContext context) {
visit(physicalPlan, context);
@ -103,7 +119,7 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
PlanFragment inputPlanFragment = visit(agg.child(0), context);
AggregationNode aggregationNode = null;
AggregationNode aggregationNode;
List<Slot> slotList = new ArrayList<>();
PhysicalAggregation physicalAggregation = agg.getOperator();
AggregateInfo.AggPhase phase = physicalAggregation.getAggPhase().toExec();
@ -129,7 +145,7 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
List<Expr> execPartitionExpressions = partitionExpressionList.stream()
.map(e -> (FunctionCallExpr) ExpressionTranslator.translate(e, context)).collect(Collectors.toList());
// todo: support DISTINCT
AggregateInfo aggInfo = null;
AggregateInfo aggInfo;
switch (phase) {
case FIRST:
aggInfo = AggregateInfo.create(execGroupingExpressions, execAggExpressions, outputTupleDesc,
@ -240,15 +256,15 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
return mergeFragment;
}
// TODO: 1. support broadcast join / co-locate / bucket shuffle join later
// TODO: 1. support shuffle join / co-locate / bucket shuffle join later
// 2. For ssb, there are only binary equal predicate, we shall support more in the future.
@Override
public PlanFragment visitPhysicalHashJoin(
PhysicalBinaryPlan<PhysicalHashJoin, Plan, Plan> hashJoin, PlanTranslatorContext context) {
// NOTICE: We must visit from right to left, to ensure the last fragment is root fragment
PlanFragment rightFragment = visit(hashJoin.child(1), context);
PlanFragment leftFragment = visit(hashJoin.child(0), context);
PlanFragment rightFragment = visit(hashJoin.child(0), context);
PhysicalHashJoin physicalHashJoin = hashJoin.getOperator();
Expression predicateExpr = physicalHashJoin.getCondition().get();
// Expression predicateExpr = physicalHashJoin.getCondition().get();
// List<Expression> eqExprList = Utils.getEqConjuncts(hashJoin.child(0).getOutput(),
@ -259,14 +275,11 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
PlanNode rightFragmentPlanRoot = rightFragment.getPlanRoot();
if (joinType.equals(JoinType.CROSS_JOIN)
|| physicalHashJoin.getJoinType().equals(JoinType.INNER_JOIN) && false /* eqExprList.isEmpty() */) {
|| physicalHashJoin.getJoinType().equals(JoinType.INNER_JOIN)
&& !physicalHashJoin.getCondition().isPresent()) {
CrossJoinNode crossJoinNode = new CrossJoinNode(context.nextNodeId(), leftFragment.getPlanRoot(),
rightFragment.getPlanRoot(), null);
crossJoinNode.setLimit(physicalHashJoin.getLimit());
List<Expr> conjuncts = Utils.extractConjuncts(predicateExpr).stream()
.map(e -> ExpressionTranslator.translate(e, context))
.collect(Collectors.toCollection(ArrayList::new));
crossJoinNode.addConjuncts(conjuncts);
ExchangeNode exchangeNode = new ExchangeNode(context.nextNodeId(), rightFragment.getPlanRoot(), false);
exchangeNode.setNumInstances(rightFragmentPlanRoot.getNumInstances());
exchangeNode.setFragment(leftFragment);
@ -277,24 +290,22 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
context.addPlanFragment(leftFragment);
return leftFragment;
}
List<Expr> execEqConjunctList = Lists.newArrayList(ExpressionTranslator.translate(predicateExpr, context));
Expression eqJoinExpression = physicalHashJoin.getCondition().get();
List<Expr> execEqConjunctList = ExpressionUtils.extractConjunct(eqJoinExpression).stream()
.map(EqualTo.class::cast)
.map(e -> swapEqualToForChildrenOrder(e, hashJoin.left().getOutput()))
.map(e -> ExpressionTranslator.translate(e, context))
.collect(Collectors.toList());
HashJoinNode hashJoinNode = new HashJoinNode(context.nextNodeId(), leftFragmentPlanRoot, rightFragmentPlanRoot,
JoinType.toJoinOperator(physicalHashJoin.getJoinType()), execEqConjunctList, Lists.newArrayList());
ExchangeNode leftExch = new ExchangeNode(context.nextNodeId(), leftFragmentPlanRoot, false);
leftExch.setNumInstances(leftFragmentPlanRoot.getNumInstances());
ExchangeNode rightExch = new ExchangeNode(context.nextNodeId(), leftFragmentPlanRoot, false);
rightExch.setNumInstances(rightFragmentPlanRoot.getNumInstances());
hashJoinNode.setDistributionMode(DistributionMode.BROADCAST);
hashJoinNode.setChild(0, leftFragmentPlanRoot);
hashJoinNode.setChild(1, leftFragmentPlanRoot);
hashJoinNode.setDistributionMode(HashJoinNode.DistributionMode.PARTITIONED);
hashJoinNode.setLimit(physicalHashJoin.getLimit());
leftFragment.setDestination((ExchangeNode) rightFragment.getPlanRoot());
rightFragment.setDestination((ExchangeNode) leftFragmentPlanRoot);
PlanFragment result = new PlanFragment(context.nextFragmentId(), hashJoinNode, leftFragment.getDataPartition());
context.addPlanFragment(result);
return result;
connectChildFragment(hashJoinNode, 1, leftFragment, rightFragment, context);
leftFragment.setPlanRoot(hashJoinNode);
return leftFragment;
}
// TODO: generate expression mapping when be project could do in ExecNode
@ -318,15 +329,18 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
requiredSlotIdList.add(((SlotRef) expr).getDesc().getId().asInt());
}
}
for (TupleId tupleId : inputPlanNode.getTupleIds()) {
TupleDescriptor tupleDescriptor = context.getTupleDesc(tupleId);
Preconditions.checkNotNull(tupleDescriptor);
List<SlotDescriptor> slotDescList = tupleDescriptor.getSlots();
slotDescList.removeIf(slotDescriptor -> !requiredSlotIdList.contains(slotDescriptor.getId().asInt()));
for (int i = 0; i < slotDescList.size(); i++) {
slotDescList.get(i).setSlotOffset(i);
}
}
return inputFragment;
}
@Override
public PlanFragment visitPhysicalFilter(
PhysicalUnaryPlan<PhysicalFilter, Plan> filterPlan, PlanTranslatorContext context) {
PlanFragment inputFragment = visit(filterPlan.child(0), context);
PlanNode planNode = inputFragment.getPlanRoot();
PhysicalFilter filter = filterPlan.getOperator();
Expression expression = filter.getPredicates();
List<Expression> expressionList = ExpressionUtils.extractConjunct(expression);
expressionList.stream().map(e -> ExpressionTranslator.translate(e, context)).forEach(planNode::addConjunct);
return inputFragment;
}
@ -340,20 +354,6 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
}
}
@Override
public PlanFragment visitPhysicalFilter(
PhysicalUnaryPlan<PhysicalFilter, Plan> filterPlan, PlanTranslatorContext context) {
PlanFragment inputFragment = visit(filterPlan.child(0), context);
PlanNode planNode = inputFragment.getPlanRoot();
PhysicalFilter filter = filterPlan.getOperator();
Expression expression = filter.getPredicates();
List<Expression> expressionList = Utils.extractConjuncts(expression);
expressionList.stream().map(e -> {
return ExpressionTranslator.translate(e, context);
}).forEach(planNode::addConjunct);
return inputFragment;
}
private TupleDescriptor generateTupleDesc(List<Slot> slotList, PlanTranslatorContext context, Table table) {
TupleDescriptor tupleDescriptor = context.generateTupleDesc();
tupleDescriptor.setTable(table);
@ -373,6 +373,16 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
return parentFragment;
}
private void connectChildFragment(PlanNode node, int childIdx,
PlanFragment parentFragment, PlanFragment childFragment,
PlanTranslatorContext context) {
ExchangeNode exchangeNode = new ExchangeNode(context.nextNodeId(), childFragment.getPlanRoot(), false);
exchangeNode.setNumInstances(childFragment.getPlanRoot().getNumInstances());
exchangeNode.setFragment(parentFragment);
node.setChild(childIdx, exchangeNode);
childFragment.setDestination(exchangeNode);
}
/**
* Helper function to eliminate unnecessary checked exception caught requirement from the main logic of translator.
*

View File

@ -0,0 +1,36 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.rules.rewrite.logical.PushPredicateThroughJoin;
import com.google.common.collect.ImmutableList;
/**
* execute predicate push down job.
*/
public class PredicatePushDownRulesJob extends BatchRulesJob {
public PredicatePushDownRulesJob(PlannerContext plannerContext) {
super(plannerContext);
rulesJob.addAll(ImmutableList.of(
topDownBatch(ImmutableList.of(
new PushPredicateThroughJoin())
)));
}
}

View File

@ -17,11 +17,6 @@
package org.apache.doris.nereids.util;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import java.util.List;
/**
* Utils for Nereids.
*/
@ -39,14 +34,4 @@ public class Utils {
return part.replace("`", "``");
}
}
// TODO: implement later
public static List<Expression> getEqConjuncts(List<Slot> left, List<Slot> right, Expression eqExpr) {
return null;
}
// TODO: implement later
public static List<Expression> extractConjuncts(Expression expr) {
return null;
}
}