[feature](nereids)Convert the expression from nereids to stale expr. (#10343)

Add ExpressionConverter.java to convert the expression from nereids to stale expression
This commit is contained in:
Kikyou1997
2022-06-25 11:16:52 +08:00
committed by GitHub
parent 7921320124
commit 3757bd521a
9 changed files with 233 additions and 46 deletions

View File

@ -29,7 +29,7 @@ import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.plans.PhysicalPlanTranslator;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanContext;
import org.apache.doris.nereids.trees.plans.PlanTranslatorContext;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlanAdapter;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
@ -65,7 +65,7 @@ public class NereidsPlanner extends Planner {
LogicalPlanAdapter logicalPlanAdapter = (LogicalPlanAdapter) queryStmt;
PhysicalPlan physicalPlan = plan(logicalPlanAdapter.getLogicalPlan(), new PhysicalProperties(), ctx);
PhysicalPlanTranslator physicalPlanTranslator = new PhysicalPlanTranslator();
PlanContext planContext = new PlanContext();
PlanTranslatorContext planContext = new PlanTranslatorContext();
physicalPlanTranslator.translatePlan(physicalPlan, planContext);
fragments = new ArrayList<>(planContext.getPlanFragmentList());
PlanFragment root = fragments.get(fragments.size() - 1);

View File

@ -58,7 +58,7 @@ public class FunctionParams {
isDistinct = false;
}
public List<Expression> getExpression() {
public List<Expression> getExpressionList() {
return expression;
}

View File

@ -18,8 +18,9 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.analysis.ArithmeticExpr;
import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.nereids.trees.NodeType;
import org.apache.doris.thrift.TExprOpcode;
/**
* All arithmetic operator.
@ -37,27 +38,39 @@ public class Arithmetic extends Expression {
*/
@SuppressWarnings("checkstyle:RegexpSingleline")
public enum ArithmeticOperator {
MULTIPLY("*", "multiply", Arithmetic.OperatorPosition.BINARY_INFIX, TExprOpcode.MULTIPLY),
DIVIDE("/", "divide", Arithmetic.OperatorPosition.BINARY_INFIX, TExprOpcode.DIVIDE),
MOD("%", "mod", Arithmetic.OperatorPosition.BINARY_INFIX, TExprOpcode.MOD),
ADD("+", "add", Arithmetic.OperatorPosition.BINARY_INFIX, TExprOpcode.ADD),
SUBTRACT("-", "subtract", Arithmetic.OperatorPosition.BINARY_INFIX, TExprOpcode.SUBTRACT),
MULTIPLY("*", "multiply",
Arithmetic.OperatorPosition.BINARY_INFIX, Operator.MULTIPLY),
DIVIDE("/", "divide",
Arithmetic.OperatorPosition.BINARY_INFIX, Operator.DIVIDE),
MOD("%", "mod",
Arithmetic.OperatorPosition.BINARY_INFIX, Operator.MOD),
ADD("+", "add",
Arithmetic.OperatorPosition.BINARY_INFIX, Operator.ADD),
SUBTRACT("-", "subtract",
Arithmetic.OperatorPosition.BINARY_INFIX, Operator.SUBTRACT),
//TODO: The following functions will be added later.
BITAND("&", "bitand", Arithmetic.OperatorPosition.BINARY_INFIX, TExprOpcode.BITAND),
BITOR("|", "bitor", Arithmetic.OperatorPosition.BINARY_INFIX, TExprOpcode.BITOR),
BITXOR("^", "bitxor", Arithmetic.OperatorPosition.BINARY_INFIX, TExprOpcode.BITXOR),
BITNOT("~", "bitnot", Arithmetic.OperatorPosition.UNARY_PREFIX, TExprOpcode.BITNOT);
BITAND("&", "bitand",
Arithmetic.OperatorPosition.BINARY_INFIX, Operator.BITAND),
BITOR("|", "bitor",
Arithmetic.OperatorPosition.BINARY_INFIX, Operator.BITOR),
BITXOR("^", "bitxor",
Arithmetic.OperatorPosition.BINARY_INFIX, Operator.BITXOR),
BITNOT("~", "bitnot",
Arithmetic.OperatorPosition.UNARY_PREFIX, Operator.BITNOT);
private final String description;
private final String name;
private final Arithmetic.OperatorPosition pos;
private final TExprOpcode opcode;
private final ArithmeticExpr.Operator staleOp;
ArithmeticOperator(String description, String name, Arithmetic.OperatorPosition pos, TExprOpcode opcode) {
ArithmeticOperator(String description,
String name,
Arithmetic.OperatorPosition pos,
ArithmeticExpr.Operator staleOp) {
this.description = description;
this.name = name;
this.pos = pos;
this.opcode = opcode;
this.staleOp = staleOp;
}
@Override
@ -73,8 +86,8 @@ public class Arithmetic extends Expression {
return pos;
}
public TExprOpcode getOpcode() {
return opcode;
public Operator getStaleOp() {
return staleOp;
}
public boolean isUnary() {

View File

@ -17,18 +17,164 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.analysis.ArithmeticExpr;
import org.apache.doris.analysis.BinaryPredicate;
import org.apache.doris.analysis.BinaryPredicate.Operator;
import org.apache.doris.analysis.BoolLiteral;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.FloatLiteral;
import org.apache.doris.analysis.FunctionCallExpr;
import org.apache.doris.analysis.IntLiteral;
import org.apache.doris.analysis.NullLiteral;
import org.apache.doris.analysis.StringLiteral;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.trees.NodeType;
import org.apache.doris.nereids.trees.plans.PlanTranslatorContext;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.NullType;
import org.apache.doris.nereids.types.StringType;
import java.util.ArrayList;
import java.util.List;
/**
* Used to convert expression of new optimizer to stale expr.
*/
public class ExpressionConverter {
@SuppressWarnings("rawtypes")
public class ExpressionConverter extends ExpressionVisitor<Expr, PlanTranslatorContext> {
public static ExpressionConverter converter = new ExpressionConverter();
// TODO: implement this, besides if expression is a slot, should set the slotId to
// converted the org.apache.doris.analysis.Expr
public Expr convert(Expression expression) {
return null;
public static Expr convert(Expression expression, PlanTranslatorContext planContext) {
return converter.visit(expression, planContext);
}
@Override
public Expr visit(Expression expr, PlanTranslatorContext context) {
return expr.accept(this, context);
}
@Override
public Expr visitSlotReference(SlotReference slotReference, PlanTranslatorContext context) {
return context.findExpr(slotReference);
}
@Override
public Expr visitEqualTo(EqualTo equalTo, PlanTranslatorContext context) {
return new BinaryPredicate(Operator.EQ,
visit(equalTo.child(0), context),
visit(equalTo.child(1), context));
}
@Override
public Expr visitGreaterThan(GreaterThan greaterThan, PlanTranslatorContext context) {
return new BinaryPredicate(Operator.GT,
visit(greaterThan.child(0), context),
visit(greaterThan.child(1), context));
}
@Override
public Expr visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, PlanTranslatorContext context) {
return new BinaryPredicate(Operator.GE,
visit(greaterThanEqual.child(0), context),
visit(greaterThanEqual.child(1), context));
}
@Override
public Expr visitLessThan(LessThan lessThan, PlanTranslatorContext context) {
return new BinaryPredicate(Operator.LT,
visit(lessThan.child(0), context),
visit(lessThan.child(1), context));
}
@Override
public Expr visitLessThanEqual(LessThanEqual lessThanEqual, PlanTranslatorContext context) {
return new BinaryPredicate(Operator.LE,
visit(lessThanEqual.child(0), context),
visit(lessThanEqual.child(1), context));
}
@Override
public Expr visitNot(Not not, PlanTranslatorContext context) {
return new org.apache.doris.analysis.CompoundPredicate(
org.apache.doris.analysis.CompoundPredicate.Operator.NOT,
visit(not.child(0), context),
null);
}
@Override
public Expr visitNullSafeEqual(NullSafeEqual nullSafeEqual, PlanTranslatorContext context) {
return new BinaryPredicate(Operator.EQ_FOR_NULL,
visit(nullSafeEqual.child(0), context),
visit(nullSafeEqual.child(1), context));
}
/**
* Convert to stale literal.
*/
@Override
public Expr visitLiteral(Literal literal, PlanTranslatorContext context) {
DataType dataType = literal.getDataType();
if (dataType instanceof BooleanType) {
return new BoolLiteral((Boolean) literal.getValue());
} else if (dataType instanceof DoubleType) {
return new FloatLiteral((Double) literal.getValue(), Type.DOUBLE);
} else if (dataType instanceof IntegerType) {
return new IntLiteral((Long) literal.getValue());
} else if (dataType instanceof NullType) {
return new NullLiteral();
} else if (dataType instanceof StringType) {
return new StringLiteral((String) literal.getValue());
}
throw new RuntimeException(String.format("Unsupported data type: %s", dataType.toString()));
}
// TODO: Supports for `distinct`
@Override
public Expr visitFunctionCall(FunctionCall function, PlanTranslatorContext context) {
List<Expr> paramList = new ArrayList<>();
for (Expression expr : function.getFnParams().getExpressionList()) {
paramList.add(visit(expr, context));
}
return new FunctionCallExpr(function.getFnName().toString(), paramList);
}
@Override
public Expr visitBetweenPredicate(BetweenPredicate betweenPredicate, PlanTranslatorContext context) {
throw new RuntimeException("Unexpected invocation");
}
@Override
public Expr visitCompoundPredicate(CompoundPredicate compoundPredicate, PlanTranslatorContext context) {
NodeType nodeType = compoundPredicate.getType();
org.apache.doris.analysis.CompoundPredicate.Operator staleOp = null;
switch (nodeType) {
case OR:
staleOp = org.apache.doris.analysis.CompoundPredicate.Operator.OR;
break;
case AND:
staleOp = org.apache.doris.analysis.CompoundPredicate.Operator.AND;
break;
case NOT:
staleOp = org.apache.doris.analysis.CompoundPredicate.Operator.NOT;
break;
default:
throw new RuntimeException(String.format("Unknown node type: %s", nodeType.name()));
}
return new org.apache.doris.analysis.CompoundPredicate(staleOp,
visit(compoundPredicate.child(0), context),
visit(compoundPredicate.child(1), context));
}
@Override
public Expr visitArithmetic(Arithmetic arithmetic, PlanTranslatorContext context) {
Arithmetic.ArithmeticOperator arithmeticOperator = arithmetic.getArithOperator();
return new ArithmeticExpr(arithmeticOperator.getStaleOp(),
visit(arithmetic.child(0), context),
arithmeticOperator.isBinary() ? visit(arithmetic.child(1), context) : null);
}
}

View File

@ -31,7 +31,7 @@ public class FunctionCall extends Expression {
private FunctionParams fnParams;
private FunctionCall(FunctionName functionName, FunctionParams functionParams) {
super(NodeType.FUNCTIONCALL, functionParams.getExpression().toArray(new Expression[0]));
super(NodeType.FUNCTIONCALL, functionParams.getExpressionList().toArray(new Expression[0]));
this.fnName = functionName;
this.fnParams = functionParams;
}
@ -47,4 +47,12 @@ public class FunctionCall extends Expression {
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitFunctionCall(this, context);
}
public FunctionName getFnName() {
return fnName;
}
public FunctionParams getFnParams() {
return fnParams;
}
}

View File

@ -21,6 +21,7 @@ import org.apache.doris.analysis.AggregateInfo;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.FunctionCallExpr;
import org.apache.doris.analysis.SlotDescriptor;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.SortInfo;
import org.apache.doris.analysis.TupleDescriptor;
import org.apache.doris.catalog.OlapTable;
@ -63,15 +64,14 @@ import java.util.stream.Collectors;
/**
* Used to translate to physical plan generated by new optimizer to the plan fragments.
*/
@SuppressWarnings("rawtypes")
public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, PlanContext> {
public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, PlanTranslatorContext> {
public void translatePlan(PhysicalPlan physicalPlan, PlanContext context) {
public void translatePlan(PhysicalPlan physicalPlan, PlanTranslatorContext context) {
visit(physicalPlan, context);
}
@Override
public PlanFragment visit(Plan plan, PlanContext context) {
public PlanFragment visit(Plan plan, PlanTranslatorContext context) {
PhysicalOperator operator = (PhysicalOperator) plan.getOperator();
return operator.accept(this, plan, context);
}
@ -81,7 +81,7 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
*/
@Override
public PlanFragment visitPhysicalAggregation(
PhysicalUnaryPlan<PhysicalAggregation, Plan> agg, PlanContext context) {
PhysicalUnaryPlan<PhysicalAggregation, Plan> agg, PlanTranslatorContext context) {
PlanFragment inputPlanFragment = visit(agg.child(0), context);
@ -93,17 +93,17 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
List<Expression> groupByExpressionList = physicalAggregation.getGroupByExprList();
ArrayList<Expr> execGroupingExpressions = groupByExpressionList.stream()
.map(e -> ExpressionConverter.converter.convert(e)).collect(Collectors.toCollection(ArrayList::new));
.map(e -> ExpressionConverter.convert(e, context)).collect(Collectors.toCollection(ArrayList::new));
List<Expression> aggExpressionList = physicalAggregation.getAggExprList();
// TODO: agg function could be other expr type either
ArrayList<FunctionCallExpr> execAggExpressions = aggExpressionList.stream()
.map(e -> (FunctionCallExpr) ExpressionConverter.converter.convert(e))
.map(e -> (FunctionCallExpr) ExpressionConverter.convert(e, context))
.collect(Collectors.toCollection(ArrayList::new));
List<Expression> partitionExpressionList = physicalAggregation.getPartitionExprList();
List<Expr> execPartitionExpressions = partitionExpressionList.stream()
.map(e -> (FunctionCallExpr) ExpressionConverter.converter.convert(e)).collect(Collectors.toList());
.map(e -> (FunctionCallExpr) ExpressionConverter.convert(e, context)).collect(Collectors.toList());
// todo: support DISTINCT
AggregateInfo aggInfo = null;
switch (phase) {
@ -132,7 +132,7 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
@Override
public PlanFragment visitPhysicalOlapScan(
PhysicalLeafPlan<PhysicalOlapScan> olapScan, PlanContext context) {
PhysicalLeafPlan<PhysicalOlapScan> olapScan, PlanTranslatorContext context) {
// Create OlapScanNode
List<Slot> slotList = olapScan.getOutput();
PhysicalOlapScan physicalOlapScan = olapScan.getOperator();
@ -148,7 +148,7 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
@Override
public PlanFragment visitPhysicalSort(PhysicalUnaryPlan<PhysicalSort, Plan> sort,
PlanContext context) {
PlanTranslatorContext context) {
PlanFragment childFragment = visit(sort.child(0), context);
PhysicalSort physicalSort = sort.getOperator();
if (!childFragment.isPartitioned()) {
@ -162,7 +162,7 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
List<OrderKey> orderKeyList = physicalSort.getOrderList();
orderKeyList.forEach(k -> {
execOrderingExprList.add(ExpressionConverter.converter.convert(k.getExpr()));
execOrderingExprList.add(ExpressionConverter.convert(k.getExpr(), context));
ascOrderList.add(k.isAsc());
nullsFirstParamList.add(k.isNullFirst());
});
@ -200,7 +200,7 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
// TODO: support broadcast join / co-locate / bucket shuffle join later
@Override
public PlanFragment visitPhysicalHashJoin(
PhysicalBinaryPlan<PhysicalHashJoin, Plan, Plan> hashJoin, PlanContext context) {
PhysicalBinaryPlan<PhysicalHashJoin, Plan, Plan> hashJoin, PlanTranslatorContext context) {
PlanFragment leftFragment = visit(hashJoin.child(0), context);
PlanFragment rightFragment = visit(hashJoin.child(0), context);
PhysicalHashJoin physicalHashJoin = hashJoin.getOperator();
@ -218,7 +218,7 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
rightFragment.getPlanRoot(), null);
crossJoinNode.setLimit(physicalHashJoin.getLimited());
List<Expr> conjuncts = Utils.extractConjuncts(predicateExpr).stream()
.map(e -> ExpressionConverter.converter.convert(e))
.map(e -> ExpressionConverter.convert(e, context))
.collect(Collectors.toCollection(ArrayList::new));
crossJoinNode.addConjuncts(conjuncts);
ExchangeNode exchangeNode = new ExchangeNode(context.nextNodeId(), rightFragment.getPlanRoot(), false);
@ -234,9 +234,9 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
List<Expression> expressionList = Utils.extractConjuncts(predicateExpr);
expressionList.removeAll(eqExprList);
List<Expr> execOtherConjunctList = expressionList.stream().map(e -> ExpressionConverter.converter.convert(e))
List<Expr> execOtherConjunctList = expressionList.stream().map(e -> ExpressionConverter.convert(e, context))
.collect(Collectors.toCollection(ArrayList::new));
List<Expr> execEqConjunctList = eqExprList.stream().map(e -> ExpressionConverter.converter.convert(e))
List<Expr> execEqConjunctList = eqExprList.stream().map(e -> ExpressionConverter.convert(e, context))
.collect(Collectors.toCollection(ArrayList::new));
HashJoinNode hashJoinNode = new HashJoinNode(context.nextNodeId(), leftFragmentPlanRoot, rightFragmentPlanRoot,
@ -259,23 +259,25 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
@Override
public PlanFragment visitPhysicalProject(
PhysicalUnaryPlan<PhysicalProject, Plan> projectPlan, PlanContext context) {
PhysicalUnaryPlan<PhysicalProject, Plan> projectPlan, PlanTranslatorContext context) {
return visit(projectPlan.child(0), context);
}
@Override
public PlanFragment visitPhysicalFilter(
PhysicalUnaryPlan<PhysicalFilter, Plan> filterPlan, PlanContext context) {
PhysicalUnaryPlan<PhysicalFilter, Plan> filterPlan, PlanTranslatorContext context) {
PlanFragment inputFragment = visit(filterPlan.child(0), context);
PlanNode planNode = inputFragment.getPlanRoot();
PhysicalFilter filter = filterPlan.getOperator();
Expression expression = filter.getPredicates();
List<Expression> expressionList = Utils.extractConjuncts(expression);
expressionList.stream().map(ExpressionConverter.converter::convert).forEach(planNode::addConjunct);
expressionList.stream().map(e -> {
return ExpressionConverter.convert(e, context);
}).forEach(planNode::addConjunct);
return inputFragment;
}
private TupleDescriptor generateTupleDesc(List<Slot> slotList, PlanContext context, Table table) {
private TupleDescriptor generateTupleDesc(List<Slot> slotList, PlanTranslatorContext context, Table table) {
TupleDescriptor tupleDescriptor = context.generateTupleDesc();
tupleDescriptor.setTable(table);
for (Slot slot : slotList) {
@ -284,12 +286,13 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
slotDescriptor.setColumn(slotReference.getColumn());
slotDescriptor.setType(slotReference.getDataType().toCatalogDataType());
slotDescriptor.setIsMaterialized(true);
context.addSlotRefMapping(slot, new SlotRef(slotDescriptor));
}
return tupleDescriptor;
}
private PlanFragment createParentFragment(PlanFragment childFragment, DataPartition parentPartition,
PlanContext ctx) {
PlanTranslatorContext ctx) {
ExchangeNode exchangeNode = new ExchangeNode(ctx.nextNodeId(), childFragment.getPlanRoot(), false);
exchangeNode.setNumInstances(childFragment.getPlanRoot().getNumInstances());
PlanFragment parentFragment = new PlanFragment(ctx.nextFragmentId(), exchangeNode, parentPartition);

View File

@ -39,7 +39,6 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalUnaryPlan;
* @param <R> Return type of each visit method.
* @param <C> Context type.
*/
@SuppressWarnings("rawtypes")
public abstract class PlanOperatorVisitor<R, C> {
public abstract R visit(Plan plan, C context);

View File

@ -18,9 +18,11 @@
package org.apache.doris.nereids.trees.plans;
import org.apache.doris.analysis.DescriptorTable;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.SlotDescriptor;
import org.apache.doris.analysis.TupleDescriptor;
import org.apache.doris.common.IdGenerator;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.planner.PlanFragment;
import org.apache.doris.planner.PlanFragmentId;
import org.apache.doris.planner.PlanNodeId;
@ -29,16 +31,23 @@ import org.apache.doris.planner.ScanNode;
import com.clearspring.analytics.util.Lists;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Context of physical plan.
*/
public class PlanContext {
public class PlanTranslatorContext {
private final List<PlanFragment> planFragmentList = Lists.newArrayList();
private final DescriptorTable descTable = new DescriptorTable();
/**
* Map expressions of new optimizer to the stale expr.
*/
private Map<Expression, Expr> expressionToExecExpr = new HashMap<>();
private final List<ScanNode> scanNodeList = new ArrayList<>();
private final IdGenerator<PlanFragmentId> fragmentIdGenerator = PlanFragmentId.createGenerator();
@ -73,6 +82,14 @@ public class PlanContext {
this.planFragmentList.add(planFragment);
}
public void addSlotRefMapping(Expression expression, Expr expr) {
expressionToExecExpr.put(expression, expr);
}
public Expr findExpr(Expression expression) {
return expressionToExecExpr.get(expression);
}
public void addScanNode(ScanNode scanNode) {
scanNodeList.add(scanNode);
}

View File

@ -61,4 +61,5 @@ public abstract class DataType {
}
public abstract Type toCatalogDataType();
}