[Feature](Prepared Statment) Implement in nereids planner (#35318) (#36172)

This commit is contained in:
lihangyu
2024-06-12 19:54:17 +08:00
committed by GitHub
parent 0b28420e1c
commit 9708ca8fcb
22 changed files with 954 additions and 247 deletions

View File

@ -29,10 +29,12 @@ import org.apache.doris.nereids.rules.analysis.ColumnAliasGenerator;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Placeholder;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.plans.ObjectId;
import org.apache.doris.nereids.trees.plans.PlaceholderId;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.TableId;
import org.apache.doris.nereids.trees.plans.algebra.Relation;
@ -118,6 +120,11 @@ public class StatementContext implements Closeable {
private final Set<String> viewDdlSqlSet = Sets.newHashSet();
private final SqlCacheContext sqlCacheContext;
// generate for next id for prepared statement's placeholders, which is connection level
private final IdGenerator<PlaceholderId> placeHolderIdGenerator = PlaceholderId.createGenerator();
// relation id to placeholders for prepared statement
private final Map<PlaceholderId, Expression> idToPlaceholderRealExpr = new HashMap<>();
// collect all hash join conditions to compute node connectivity in join graph
private final List<Expression> joinFilters = new ArrayList<>();
@ -141,6 +148,9 @@ public class StatementContext implements Closeable {
// table locks
private final Stack<CloseableResource> plannerResources = new Stack<>();
// placeholder params for prepared statement
private List<Placeholder> placeholders;
// for create view support in nereids
// key is the start and end position of the sql substring that needs to be replaced,
// and value is the new string used for replacement.
@ -367,6 +377,14 @@ public class StatementContext implements Closeable {
return consumerIdToFilters;
}
public PlaceholderId getNextPlaceholderId() {
return placeHolderIdGenerator.getNextId();
}
public Map<PlaceholderId, Expression> getIdToPlaceholderRealExpr() {
return idToPlaceholderRealExpr;
}
public Map<CTEId, List<Pair<Map<Slot, Slot>, Group>>> getCteIdToConsumerGroup() {
return cteIdToConsumerGroup;
}
@ -487,6 +505,14 @@ public class StatementContext implements Closeable {
releasePlannerResources();
}
public List<Placeholder> getPlaceholders() {
return placeholders;
}
public void setPlaceholders(List<Placeholder> placeholders) {
this.placeholders = placeholders;
}
private static class CloseableResource implements Closeable {
public final String resourceName;
public final String threadName;

View File

@ -260,6 +260,7 @@ import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Placeholder;
import org.apache.doris.nereids.trees.expressions.Properties;
import org.apache.doris.nereids.trees.expressions.Regexp;
import org.apache.doris.nereids.trees.expressions.ScalarSubquery;
@ -487,6 +488,16 @@ import java.util.stream.Collectors;
public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
private final boolean forCreateView;
// Sort the parameters with token position to keep the order with original placeholders
// in prepared statement.Otherwise, the order maybe broken
private final Map<Token, Placeholder> tokenPosToParameters = Maps.newTreeMap((pos1, pos2) -> {
int line = pos1.getLine() - pos2.getLine();
if (line != 0) {
return line;
}
return pos1.getCharPositionInLine() - pos2.getCharPositionInLine();
});
public LogicalPlanBuilder() {
forCreateView = false;
}
@ -1003,6 +1014,9 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
}
logicalPlans.add(Pair.of(
ParserUtils.withOrigin(ctx, () -> (LogicalPlan) visit(statement)), statementContext));
List<Placeholder> params = new ArrayList<>(tokenPosToParameters.values());
statementContext.setPlaceholders(params);
tokenPosToParameters.clear();
}
return logicalPlans;
}
@ -2313,6 +2327,13 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
return new VarcharLiteral(s, strLength);
}
@Override
public Expression visitPlaceholder(DorisParser.PlaceholderContext ctx) {
Placeholder parameter = new Placeholder(ConnectContext.get().getStatementContext().getNextPlaceholderId());
tokenPosToParameters.put(ctx.start, parameter);
return parameter;
}
/**
* cast all items to same types.
* TODO remove this function after we refactor type coercion.

View File

@ -56,6 +56,7 @@ import org.apache.doris.nereids.trees.expressions.ListQuery;
import org.apache.doris.nereids.trees.expressions.Match;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Placeholder;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
@ -541,6 +542,13 @@ public class ExpressionAnalyzer extends SubExprAnalyzer<ExpressionRewriteContext
return expr;
}
@Override
public Expression visitPlaceholder(Placeholder placeholder, ExpressionRewriteContext context) {
Expression realExpr = context.cascadesContext.getStatementContext()
.getIdToPlaceholderRealExpr().get(placeholder.getPlaceholderId());
return visit(realExpr, context);
}
@Override
public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRewriteContext context) {
Expression left = cp.left().accept(this, context);

View File

@ -0,0 +1,78 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.catalog.MysqlColType;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.shape.LeafExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.trees.plans.PlaceholderId;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.NullType;
import java.util.Optional;
/**
* Placeholder for prepared statement
*/
public class Placeholder extends Expression implements LeafExpression {
private final PlaceholderId placeholderId;
private final Optional<MysqlColType> mysqlColType;
public Placeholder(PlaceholderId placeholderId) {
this.placeholderId = placeholderId;
this.mysqlColType = Optional.empty();
}
public Placeholder(PlaceholderId placeholderId, MysqlColType mysqlColType) {
this.placeholderId = placeholderId;
this.mysqlColType = Optional.of(mysqlColType);
}
public PlaceholderId getPlaceholderId() {
return placeholderId;
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitPlaceholder(this, context);
}
@Override
public boolean nullable() {
return true;
}
@Override
public String toSql() {
return "?";
}
@Override
public DataType getDataType() throws UnboundException {
return NullType.INSTANCE;
}
public Placeholder withNewMysqlColType(MysqlColType mysqlColType) {
return new Placeholder(getPlaceholderId(), mysqlColType);
}
public MysqlColType getMysqlColType() {
return mysqlColType.get();
}
}

View File

@ -19,8 +19,10 @@ package org.apache.doris.nereids.trees.expressions.literal;
import org.apache.doris.analysis.BoolLiteral;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.catalog.MysqlColType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Config;
import org.apache.doris.mysql.MysqlProto;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.Expression;
@ -28,6 +30,7 @@ import org.apache.doris.nereids.trees.expressions.shape.LeafExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.CharType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
@ -40,6 +43,7 @@ import com.google.common.collect.ImmutableList;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
@ -396,4 +400,169 @@ public abstract class Literal extends Expression implements LeafExpression, Comp
}
return false;
}
/**
** get paramter length, port from mysql get_param_length
**/
public static int getParmLen(ByteBuffer data) {
int maxLen = data.remaining();
if (maxLen < 1) {
return 0;
}
// get and advance 1 byte
int len = MysqlProto.readInt1(data);
if (len == 252) {
if (maxLen < 3) {
return 0;
}
// get and advance 2 bytes
return MysqlProto.readInt2(data);
} else if (len == 253) {
if (maxLen < 4) {
return 0;
}
// get and advance 3 bytes
return MysqlProto.readInt3(data);
} else if (len == 254) {
/*
In our client-server protocol all numbers bigger than 2^24
stored as 8 bytes with uint8korr. Here we always know that
parameter length is less than 2^4 so we don't look at the second
4 bytes. But still we need to obey the protocol hence 9 in the
assignment below.
*/
if (maxLen < 9) {
return 0;
}
len = MysqlProto.readInt4(data);
MysqlProto.readFixedString(data, 4);
return len;
} else if (len == 255) {
return 0;
} else {
return len;
}
}
/**
* Retrieves a Literal object based on the MySQL type and the data provided.
*
* @param mysqlType the MySQL type identifier
* @param data the ByteBuffer containing the data
* @return a Literal object corresponding to the MySQL type
* @throws AnalysisException if the MySQL type is unsupported or if data conversion fails
* @link <a href="https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_binary_resultset.html">...</a>.
*/
public static Literal getLiteralByMysqlType(MysqlColType mysqlType, ByteBuffer data) throws AnalysisException {
switch (mysqlType) {
case MYSQL_TYPE_TINY:
return new TinyIntLiteral(data.get());
case MYSQL_TYPE_SHORT:
return new SmallIntLiteral((short) data.getChar());
case MYSQL_TYPE_LONG:
return new IntegerLiteral(data.getInt());
case MYSQL_TYPE_LONGLONG:
return new BigIntLiteral(data.getLong());
case MYSQL_TYPE_FLOAT:
return new FloatLiteral(data.getFloat());
case MYSQL_TYPE_DOUBLE:
return new DoubleLiteral(data.getDouble());
case MYSQL_TYPE_DECIMAL:
case MYSQL_TYPE_NEWDECIMAL:
return handleDecimalLiteral(data);
case MYSQL_TYPE_DATE:
return handleDateLiteral(data);
case MYSQL_TYPE_DATETIME:
case MYSQL_TYPE_TIMESTAMP:
case MYSQL_TYPE_TIMESTAMP2:
return handleDateTimeLiteral(data);
case MYSQL_TYPE_STRING:
case MYSQL_TYPE_VARSTRING:
return handleStringLiteral(data);
case MYSQL_TYPE_VARCHAR:
return handleVarcharLiteral(data);
default:
throw new AnalysisException("Unsupported MySQL type: " + mysqlType);
}
}
private static Literal handleDecimalLiteral(ByteBuffer data) throws AnalysisException {
int len = getParmLen(data);
byte[] bytes = new byte[len];
data.get(bytes);
try {
String value = new String(bytes);
BigDecimal v = new BigDecimal(value);
if (Config.enable_decimal_conversion) {
return new DecimalV3Literal(v);
}
return new DecimalLiteral(v);
} catch (NumberFormatException e) {
throw new AnalysisException("Invalid decimal literal", e);
}
}
private static Literal handleDateLiteral(ByteBuffer data) {
int len = getParmLen(data);
if (len >= 4) {
int year = (int) data.getChar();
int month = (int) data.get();
int day = (int) data.get();
if (Config.enable_date_conversion) {
return new DateV2Literal(year, month, day);
}
return new DateLiteral(year, month, day);
} else {
if (Config.enable_date_conversion) {
return new DateV2Literal(0, 1, 1);
}
return new DateLiteral(0, 1, 1);
}
}
private static Literal handleDateTimeLiteral(ByteBuffer data) {
int len = getParmLen(data);
if (len >= 4) {
int year = (int) data.getChar();
int month = (int) data.get();
int day = (int) data.get();
int hour = 0;
int minute = 0;
int second = 0;
int microsecond = 0;
if (len > 4) {
hour = (int) data.get();
minute = (int) data.get();
second = (int) data.get();
}
if (len > 7) {
microsecond = data.getInt();
}
if (Config.enable_date_conversion) {
return new DateTimeV2Literal(year, month, day, hour, minute, second, microsecond);
}
return new DateTimeLiteral(DateTimeType.INSTANCE, year, month, day, hour, minute, second, microsecond);
} else {
if (Config.enable_date_conversion) {
return new DateTimeV2Literal(0, 1, 1, 0, 0, 0);
}
return new DateTimeLiteral(0, 1, 1, 0, 0, 0);
}
}
private static Literal handleStringLiteral(ByteBuffer data) {
int strLen = getParmLen(data);
strLen = Math.min(strLen, data.remaining());
byte[] bytes = new byte[strLen];
data.get(bytes);
return new StringLiteral(new String(bytes));
}
private static Literal handleVarcharLiteral(ByteBuffer data) {
int strLen = getParmLen(data);
strLen = Math.min(strLen, data.remaining());
byte[] bytes = new byte[strLen];
data.get(bytes);
return new VarcharLiteral(new String(bytes));
}
}

View File

@ -69,6 +69,7 @@ import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Placeholder;
import org.apache.doris.nereids.trees.expressions.Properties;
import org.apache.doris.nereids.trees.expressions.ScalarSubquery;
import org.apache.doris.nereids.trees.expressions.Slot;
@ -505,6 +506,10 @@ public abstract class ExpressionVisitor<R, C>
return visitMatch(matchPhraseEdge, context);
}
public R visitPlaceholder(Placeholder placeholder, C context) {
return visit(placeholder, context);
}
public R visitAny(Any any, C context) {
return visit(any, context);
}

View File

@ -0,0 +1,58 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.plans;
import org.apache.doris.common.Id;
import org.apache.doris.common.IdGenerator;
/**
* placeholder id for prepared statement parameters
*/
public class PlaceholderId extends Id<PlaceholderId> {
public PlaceholderId(int id) {
super(id);
}
/**
* Should be only called by {@link org.apache.doris.nereids.StatementContext}.
*/
public static IdGenerator<PlaceholderId> createGenerator() {
return new IdGenerator<PlaceholderId>() {
@Override
public PlaceholderId getNextId() {
return new PlaceholderId(nextId++);
}
};
}
@Override
public String toString() {
return "PlaceholderId#" + id;
}
@Override
public boolean equals(Object obj) {
return super.equals(obj);
}
@Override
public int hashCode() {
return super.hashCode();
}
}

View File

@ -156,5 +156,8 @@ public enum PlanType {
CREATE_VIEW_COMMAND,
ALTER_VIEW_COMMAND,
UNSUPPORTED_COMMAND,
CREATE_TABLE_LIKE_COMMAND
CREATE_TABLE_LIKE_COMMAND,
PREPARED_COMMAND,
EXECUTE_COMMAND
}

View File

@ -0,0 +1,83 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.plans.commands;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.PreparedStatementContext;
import org.apache.doris.qe.StmtExecutor;
import java.util.List;
import java.util.stream.Collectors;
/**
* Prepared Statement
*/
public class ExecuteCommand extends Command {
private final String stmtName;
private final PrepareCommand prepareCommand;
private final StatementContext statementContext;
public ExecuteCommand(String stmtName, PrepareCommand prepareCommand, StatementContext statementContext) {
super(PlanType.EXECUTE_COMMAND);
this.stmtName = stmtName;
this.prepareCommand = prepareCommand;
this.statementContext = statementContext;
}
public String getStmtName() {
return stmtName;
}
@Override
public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
return visitor.visit(this, context);
}
@Override
public void run(ConnectContext ctx, StmtExecutor executor) throws Exception {
PreparedStatementContext preparedStmtCtx = ctx.getPreparedStementContext(stmtName);
if (null == preparedStmtCtx) {
throw new AnalysisException(
"prepare statement " + stmtName + " not found, maybe expired");
}
PrepareCommand prepareCommand = (PrepareCommand) preparedStmtCtx.command;
LogicalPlanAdapter planAdapter = new LogicalPlanAdapter(prepareCommand.getLogicalPlan(), executor.getContext()
.getStatementContext());
executor.setParsedStmt(planAdapter);
// execute real statement
executor.execute();
}
/**
* return the sql representation contains real expr instead of placeholders
*/
public String toSql() {
// maybe slow
List<Expression> realValueExpr = prepareCommand.getPlaceholders().stream()
.map(placeholder -> statementContext.getIdToPlaceholderRealExpr().get(placeholder.getPlaceholderId()))
.collect(Collectors.toList());
return "EXECUTE `" + stmtName + "`"
+ realValueExpr.stream().map(Expression::toSql).collect(Collectors.joining(", ", " USING ", ""));
}
}

View File

@ -0,0 +1,120 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.plans.commands;
import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.nereids.trees.expressions.Placeholder;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.OriginStatement;
import org.apache.doris.qe.PreparedStatementContext;
import org.apache.doris.qe.StmtExecutor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.List;
/**
* Prepared Statement
*/
public class PrepareCommand extends Command {
private static final Logger LOG = LogManager.getLogger(StmtExecutor.class);
private final List<Placeholder> placeholders = new ArrayList<>();
private final LogicalPlan logicalPlan;
private final String name;
private final OriginStatement originalStmt;
/**
* constructor
* @param name the statement name which represents statement id for prepared statement
* @param plan the inner statement
* @param placeholders the parameters for this prepared statement
* @param originalStmt original statement from StmtExecutor
*/
public PrepareCommand(String name, LogicalPlan plan, List<Placeholder> placeholders,
OriginStatement originalStmt) {
super(PlanType.PREPARED_COMMAND);
this.logicalPlan = plan;
this.placeholders.addAll(placeholders);
this.name = name;
this.originalStmt = originalStmt;
}
public String getName() {
return name;
}
public List<Placeholder> getPlaceholders() {
return placeholders;
}
public int placeholderCount() {
return placeholders.size();
}
public LogicalPlan getLogicalPlan() {
return logicalPlan;
}
public OriginStatement getOriginalStmt() {
return originalStmt;
}
/**
* return the labels of parameters
*/
public List<String> getLabels() {
List<String> labels = new ArrayList<>();
for (Placeholder parameter : placeholders) {
labels.add("$" + parameter.getPlaceholderId().asInt());
}
return labels;
}
// register prepared statement with attached statement id
@Override
public void run(ConnectContext ctx, StmtExecutor executor) throws Exception {
List<String> labels = getLabels();
// register prepareStmt
if (LOG.isDebugEnabled()) {
LOG.debug("add prepared statement {}, isBinaryProtocol {}",
name, ctx.getCommand() == MysqlCommand.COM_STMT_PREPARE);
}
ctx.addPreparedStatementContext(name,
new PreparedStatementContext(this, ctx, ctx.getStatementContext(), name));
if (ctx.getCommand() == MysqlCommand.COM_STMT_PREPARE) {
executor.sendStmtPrepareOK((int) ctx.getStmtId(), labels);
}
}
@Override
public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
return visitor.visit(this, context);
}
public PrepareCommand withPlaceholders(List<Placeholder> placeholders) {
return new PrepareCommand(this.name, this.logicalPlan, placeholders, this.originalStmt);
}
}

View File

@ -246,8 +246,13 @@ public class ConnectContext {
}
private StatementContext statementContext;
// legacy planner
private Map<String, PrepareStmtContext> preparedStmtCtxs = Maps.newHashMap();
// new planner
private Map<String, PreparedStatementContext> preparedStatementContextMap = Maps.newHashMap();
private List<TableIf> tables = null;
private Map<String, ColumnStatistic> totalColumnStatisticMap = new HashMap<>();
@ -384,6 +389,10 @@ public class ConnectContext {
this.preparedStmtCtxs.put(stmtName, ctx);
}
public void addPreparedStatementContext(String stmtName, PreparedStatementContext ctx) {
this.preparedStatementContextMap.put(stmtName, ctx);
}
public void removePrepareStmt(String stmtName) {
this.preparedStmtCtxs.remove(stmtName);
}
@ -392,6 +401,10 @@ public class ConnectContext {
return this.preparedStmtCtxs.get(stmtName);
}
public PreparedStatementContext getPreparedStementContext(String stmtName) {
return this.preparedStatementContextMap.get(stmtName);
}
public List<TableIf> getTables() {
return tables;
}

View File

@ -222,8 +222,11 @@ public abstract class ConnectProcessor {
Exception nereidsParseException = null;
long parseSqlStartTime = System.currentTimeMillis();
List<StatementBase> cachedStmts = null;
// Nereids do not support prepare and execute now, so forbid prepare command, only process query command
if (mysqlCommand == MysqlCommand.COM_QUERY && sessionVariable.isEnableNereidsPlanner()) {
// Currently we add a config to decide whether using PREPARED/EXECUTE command for nereids
// TODO: after implemented full prepared, we could remove this flag
boolean nereidsUseServerPrep = sessionVariable.enableServeSidePreparedStatement
|| mysqlCommand == MysqlCommand.COM_QUERY;
if (nereidsUseServerPrep && sessionVariable.isEnableNereidsPlanner()) {
if (wantToParseSqlFromSqlCache) {
cachedStmts = parseFromSqlCache(originStmt);
if (cachedStmts != null) {

View File

@ -18,15 +18,24 @@
package org.apache.doris.qe;
import org.apache.doris.analysis.ExecuteStmt;
import org.apache.doris.analysis.InsertStmt;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.analysis.NullLiteral;
import org.apache.doris.analysis.PrepareStmt;
import org.apache.doris.analysis.QueryStmt;
import org.apache.doris.analysis.StatementBase;
import org.apache.doris.catalog.MysqlColType;
import org.apache.doris.common.ErrorCode;
import org.apache.doris.common.ErrorReport;
import org.apache.doris.mysql.MysqlChannel;
import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.mysql.MysqlProto;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.nereids.trees.expressions.Placeholder;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.PlaceholderId;
import org.apache.doris.nereids.trees.plans.commands.ExecuteCommand;
import org.apache.doris.nereids.trees.plans.commands.PrepareCommand;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@ -81,35 +90,12 @@ public class MysqlConnectProcessor extends ConnectProcessor {
}
}
// process COM_EXECUTE, parse binary row data
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
private void handleExecute() {
// debugPacket();
packetBuf = packetBuf.order(ByteOrder.LITTLE_ENDIAN);
// parse stmt_id, flags, params
int stmtId = packetBuf.getInt();
// flag
packetBuf.get();
// iteration_count always 1,
packetBuf.getInt();
if (LOG.isDebugEnabled()) {
LOG.debug("execute prepared statement {}", stmtId);
}
PrepareStmtContext prepareCtx = ctx.getPreparedStmt(String.valueOf(stmtId));
if (prepareCtx == null) {
if (LOG.isDebugEnabled()) {
LOG.debug("No such statement in context, stmtId:{}", stmtId);
}
ctx.getState().setError(ErrorCode.ERR_UNKNOWN_COM_ERROR,
"msg: Not supported such prepared statement");
return;
}
ctx.setStartTime();
if (prepareCtx.stmt.getInnerStmt() instanceof QueryStmt) {
private void handleExecute(PrepareStmt prepareStmt, long stmtId) {
if (prepareStmt.getInnerStmt() instanceof QueryStmt) {
ctx.getState().setIsQuery(true);
}
prepareCtx.stmt.setIsPrepared();
int paramCount = prepareCtx.stmt.getParmCount();
prepareStmt.setIsPrepared();
int paramCount = prepareStmt.getParmCount();
LOG.debug("execute prepared statement {}, paramCount {}", stmtId, paramCount);
// null bitmap
String stmtStr = "";
@ -124,7 +110,7 @@ public class MysqlConnectProcessor extends ConnectProcessor {
for (int i = 0; i < paramCount; ++i) {
int typeCode = packetBuf.getChar();
LOG.debug("code {}", typeCode);
prepareCtx.stmt.placeholders().get(i).setTypeCode(typeCode);
prepareStmt.placeholders().get(i).setTypeCode(typeCode);
}
}
// parse param data
@ -133,7 +119,7 @@ public class MysqlConnectProcessor extends ConnectProcessor {
realValueExprs.add(new NullLiteral());
continue;
}
LiteralExpr l = prepareCtx.stmt.placeholders().get(i).createLiteralFromType();
LiteralExpr l = prepareStmt.placeholders().get(i).createLiteralFromType();
l.setupParamFromBinary(packetBuf);
realValueExprs.add(l);
}
@ -149,7 +135,7 @@ public class MysqlConnectProcessor extends ConnectProcessor {
ctx.setExecutor(executor);
executor.execute();
PrepareStmtContext preparedStmtContext = ConnectContext.get().getPreparedStmt(String.valueOf(stmtId));
if (preparedStmtContext != null && !(preparedStmtContext.stmt.getInnerStmt() instanceof InsertStmt)) {
if (preparedStmtContext != null) {
stmtStr = executeStmt.toSql();
}
} catch (Throwable e) {
@ -159,8 +145,101 @@ public class MysqlConnectProcessor extends ConnectProcessor {
ctx.getState().setError(ErrorCode.ERR_UNKNOWN_ERROR,
e.getClass().getSimpleName() + ", msg: " + e.getMessage());
}
if (!stmtStr.isEmpty()) {
auditAfterExec(stmtStr, prepareCtx.stmt.getInnerStmt(), null, false);
auditAfterExec(stmtStr, executor.getParsedStmt(), executor.getQueryStatisticsForAuditLog(), true);
}
private void handleExecute(PrepareCommand prepareCommand, long stmtId, PreparedStatementContext prepCtx) {
int paramCount = prepareCommand.placeholderCount();
LOG.debug("execute prepared statement {}, paramCount {}", stmtId, paramCount);
// null bitmap
String stmtStr = "";
try {
StatementContext statementContext = prepCtx.statementContext;
if (paramCount > 0) {
byte[] nullbitmapData = new byte[(paramCount + 7) / 8];
packetBuf.get(nullbitmapData);
// new_params_bind_flag
if ((int) packetBuf.get() != 0) {
List<Placeholder> typedPlaceholders = new ArrayList<>();
// parse params's types
for (int i = 0; i < paramCount; ++i) {
int typeCode = packetBuf.getChar();
LOG.debug("code {}", typeCode);
// assign type to placeholders
typedPlaceholders.add(
prepareCommand.getPlaceholders().get(i)
.withNewMysqlColType(MysqlColType.fromCode(typeCode)));
}
// rewrite with new prepared statment with type info in placeholders
prepCtx.command = prepareCommand.withPlaceholders(typedPlaceholders);
prepareCommand = (PrepareCommand) prepCtx.command;
}
// parse param data
for (int i = 0; i < paramCount; ++i) {
PlaceholderId exprId = prepareCommand.getPlaceholders().get(i).getPlaceholderId();
if (isNull(nullbitmapData, i)) {
statementContext.getIdToPlaceholderRealExpr().put(exprId,
new org.apache.doris.nereids.trees.expressions.literal.NullLiteral());
continue;
}
MysqlColType type = prepareCommand.getPlaceholders().get(i).getMysqlColType();
Literal l = Literal.getLiteralByMysqlType(type, packetBuf);
statementContext.getIdToPlaceholderRealExpr().put(exprId, l);
}
}
ExecuteCommand executeStmt = new ExecuteCommand(String.valueOf(stmtId), prepareCommand, statementContext);
// TODO set real origin statement
if (LOG.isDebugEnabled()) {
LOG.debug("executeStmt {}", executeStmt);
}
StatementBase stmt = new LogicalPlanAdapter(executeStmt, statementContext);
stmt.setOrigStmt(prepareCommand.getOriginalStmt());
executor = new StmtExecutor(ctx, stmt);
ctx.setExecutor(executor);
executor.execute();
stmtStr = executeStmt.toSql();
} catch (Throwable e) {
// Catch all throwable.
// If reach here, maybe doris bug.
LOG.warn("Process one query failed because unknown reason: ", e);
ctx.getState().setError(ErrorCode.ERR_UNKNOWN_ERROR,
e.getClass().getSimpleName() + ", msg: " + e.getMessage());
}
auditAfterExec(stmtStr, executor.getParsedStmt(), executor.getQueryStatisticsForAuditLog(), true);
}
// process COM_EXECUTE, parse binary row data
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
private void handleExecute() {
// debugPacket();
packetBuf = packetBuf.order(ByteOrder.LITTLE_ENDIAN);
// parse stmt_id, flags, params
int stmtId = packetBuf.getInt();
// flag
packetBuf.get();
// iteration_count always 1,
packetBuf.getInt();
if (LOG.isDebugEnabled()) {
LOG.debug("execute prepared statement {}", stmtId);
}
PrepareStmtContext prepareCtx = ctx.getPreparedStmt(String.valueOf(stmtId));
ctx.setStartTime();
if (prepareCtx != null) {
// get from lagacy planner context, to be removed
handleExecute((PrepareStmt) prepareCtx.stmt, stmtId);
} else {
// nererids
PreparedStatementContext preparedStatementContext = ctx.getPreparedStementContext(String.valueOf(stmtId));
if (preparedStatementContext == null) {
if (LOG.isDebugEnabled()) {
LOG.debug("No such statement in context, stmtId:{}", stmtId);
}
ctx.getState().setError(ErrorCode.ERR_UNKNOWN_COM_ERROR,
"msg: Not supported such prepared statement");
return;
}
handleExecute(preparedStatementContext.command, stmtId, preparedStatementContext);
}
}

View File

@ -18,17 +18,15 @@
package org.apache.doris.qe;
import org.apache.doris.analysis.Analyzer;
import org.apache.doris.analysis.PrepareStmt;
import org.apache.doris.planner.OriginalPlanner;
import org.apache.doris.analysis.StatementBase;
import org.apache.doris.planner.Planner;
import com.google.common.base.Preconditions;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
public class PrepareStmtContext {
private static final Logger LOG = LogManager.getLogger(PrepareStmtContext.class);
public PrepareStmt stmt;
public StatementBase stmt;
public ConnectContext ctx;
public Planner planner;
public Analyzer analyzer;
@ -37,15 +35,11 @@ public class PrepareStmtContext {
// Timestamp in millisecond last command starts at
protected volatile long startTime;
public PrepareStmtContext(PrepareStmt stmt, ConnectContext ctx, Planner planner,
public PrepareStmtContext(StatementBase stmt, ConnectContext ctx, Planner planner,
Analyzer analyzer, String stmtString) {
this.stmt = stmt;
this.ctx = ctx;
this.planner = planner;
// Only support OriginalPlanner for now
if (planner != null) {
Preconditions.checkState(planner instanceof OriginalPlanner);
}
this.analyzer = analyzer;
this.stmtString = stmtString;
}

View File

@ -0,0 +1,47 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.qe;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.trees.plans.commands.PrepareCommand;
public class PreparedStatementContext {
public PrepareCommand command;
public ConnectContext ctx;
StatementContext statementContext;
public String stmtString;
// Timestamp in millisecond last command starts at
protected volatile long startTime;
public PreparedStatementContext(PrepareCommand command,
ConnectContext ctx, StatementContext statementContext, String stmtString) {
this.command = command;
this.ctx = ctx;
this.statementContext = statementContext;
this.stmtString = stmtString;
}
public long getStartTime() {
return startTime;
}
public void setStartTime() {
startTime = System.currentTimeMillis();
}
}

View File

@ -141,6 +141,7 @@ import org.apache.doris.nereids.trees.plans.commands.Command;
import org.apache.doris.nereids.trees.plans.commands.CreateTableCommand;
import org.apache.doris.nereids.trees.plans.commands.Forward;
import org.apache.doris.nereids.trees.plans.commands.NotAllowFallback;
import org.apache.doris.nereids.trees.plans.commands.PrepareCommand;
import org.apache.doris.nereids.trees.plans.commands.insert.BatchInsertIntoTableCommand;
import org.apache.doris.nereids.trees.plans.commands.insert.InsertIntoTableCommand;
import org.apache.doris.nereids.trees.plans.commands.insert.InsertOverwriteTableCommand;
@ -261,7 +262,7 @@ public class StmtExecutor {
private Data.PQueryStatistics.Builder statisticsForAuditLog;
private boolean isCached;
private String stmtName;
private PrepareStmt prepareStmt = null;
private StatementBase prepareStmt = null;
private String mysqlLoadId;
// Distinguish from prepare and execute command
private boolean isExecuteStmt = false;
@ -644,6 +645,14 @@ public class StmtExecutor {
"Nereids only process LogicalPlanAdapter, but parsedStmt is " + parsedStmt.getClass().getName());
context.getState().setNereids(true);
LogicalPlan logicalPlan = ((LogicalPlanAdapter) parsedStmt).getLogicalPlan();
if (context.getCommand() == MysqlCommand.COM_STMT_PREPARE) {
if (isForwardToMaster()) {
throw new UserException("Forward master command is not supported for prepare statement");
}
logicalPlan = new PrepareCommand(String.valueOf(context.getStmtId()),
logicalPlan, statementContext.getPlaceholders(), originStmt);
}
// when we in transaction mode, we only support insert into command and transaction command
if (context.isTxnModel()) {
if (!(logicalPlan instanceof BatchInsertIntoTableCommand
@ -1075,8 +1084,8 @@ public class StmtExecutor {
throw new UserException("Could not execute, since `" + execStmt.getName() + "` not exist");
}
// parsedStmt may already by set when constructing this StmtExecutor();
preparedStmtCtx.stmt.asignValues(execStmt.getArgs());
parsedStmt = preparedStmtCtx.stmt.getInnerStmt();
((PrepareStmt) preparedStmtCtx.stmt).asignValues(execStmt.getArgs());
parsedStmt = ((PrepareStmt) preparedStmtCtx.stmt).getInnerStmt();
planner = preparedStmtCtx.planner;
analyzer = preparedStmtCtx.analyzer;
prepareStmt = preparedStmtCtx.stmt;
@ -1084,7 +1093,7 @@ public class StmtExecutor {
LOG.debug("already prepared stmt: {}", preparedStmtCtx.stmtString);
}
isExecuteStmt = true;
if (!preparedStmtCtx.stmt.needReAnalyze()) {
if (!((PrepareStmt) preparedStmtCtx.stmt).needReAnalyze()) {
// Return directly to bypass analyze and plan
return;
}
@ -1106,15 +1115,15 @@ public class StmtExecutor {
if (parsedStmt instanceof PrepareStmt || context.getCommand() == MysqlCommand.COM_STMT_PREPARE) {
if (context.getCommand() == MysqlCommand.COM_STMT_PREPARE) {
prepareStmt = new PrepareStmt(parsedStmt,
String.valueOf(context.getEnv().getNextStmtId()));
String.valueOf(String.valueOf(context.getStmtId())));
} else {
prepareStmt = (PrepareStmt) parsedStmt;
}
prepareStmt.setContext(context);
((PrepareStmt) prepareStmt).setContext(context);
prepareStmt.analyze(analyzer);
// Need analyze inner statement
parsedStmt = prepareStmt.getInnerStmt();
if (prepareStmt.getPreparedType() == PrepareStmt.PreparedType.STATEMENT) {
parsedStmt = ((PrepareStmt) prepareStmt).getInnerStmt();
if (((PrepareStmt) prepareStmt).getPreparedType() == PrepareStmt.PreparedType.STATEMENT) {
// Skip analyze, do it lazy
return;
}
@ -1207,15 +1216,15 @@ public class StmtExecutor {
}
}
if (preparedStmtReanalyzed
&& preparedStmtCtx.stmt.getPreparedType() == PrepareStmt.PreparedType.FULL_PREPARED) {
prepareStmt.asignValues(execStmt.getArgs());
&& ((PrepareStmt) preparedStmtCtx.stmt).getPreparedType() == PrepareStmt.PreparedType.FULL_PREPARED) {
((PrepareStmt) prepareStmt).asignValues(execStmt.getArgs());
if (LOG.isDebugEnabled()) {
LOG.debug("update planner and analyzer after prepared statement reanalyzed");
}
preparedStmtCtx.planner = planner;
preparedStmtCtx.analyzer = analyzer;
Preconditions.checkNotNull(preparedStmtCtx.stmt);
preparedStmtCtx.analyzer.setPrepareStmt(preparedStmtCtx.stmt);
preparedStmtCtx.analyzer.setPrepareStmt(((PrepareStmt) preparedStmtCtx.stmt));
}
}
@ -1273,9 +1282,9 @@ public class StmtExecutor {
}
}
if (prepareStmt != null) {
analyzer.setPrepareStmt(prepareStmt);
if (execStmt != null && prepareStmt.getPreparedType() != PreparedType.FULL_PREPARED) {
prepareStmt.asignValues(execStmt.getArgs());
analyzer.setPrepareStmt(((PrepareStmt) prepareStmt));
if (execStmt != null && ((PrepareStmt) prepareStmt).getPreparedType() != PreparedType.FULL_PREPARED) {
((PrepareStmt) prepareStmt).asignValues(execStmt.getArgs());
}
}
parsedStmt.analyze(analyzer);
@ -1342,9 +1351,10 @@ public class StmtExecutor {
// query re-analyze
parsedStmt.reset();
if (prepareStmt != null) {
analyzer.setPrepareStmt(prepareStmt);
if (execStmt != null && prepareStmt.getPreparedType() != PreparedType.FULL_PREPARED) {
prepareStmt.asignValues(execStmt.getArgs());
analyzer.setPrepareStmt(((PrepareStmt) prepareStmt));
if (execStmt != null
&& ((PrepareStmt) prepareStmt).getPreparedType() != PreparedType.FULL_PREPARED) {
((PrepareStmt) prepareStmt).asignValues(execStmt.getArgs());
}
}
analyzer.setReAnalyze(true);
@ -2373,16 +2383,17 @@ public class StmtExecutor {
}
private void handlePrepareStmt() throws Exception {
List<String> labels = ((PrepareStmt) prepareStmt).getColLabelsOfPlaceHolders();
// register prepareStmt
if (LOG.isDebugEnabled()) {
LOG.debug("add prepared statement {}, isBinaryProtocol {}",
prepareStmt.getName(), context.getCommand() == MysqlCommand.COM_STMT_PREPARE);
prepareStmt.toSql(), context.getCommand() == MysqlCommand.COM_STMT_PREPARE);
}
context.addPreparedStmt(prepareStmt.getName(),
context.addPreparedStmt(String.valueOf(context.getStmtId()),
new PrepareStmtContext(prepareStmt,
context, planner, analyzer, prepareStmt.getName()));
context, planner, analyzer, String.valueOf(context.getStmtId())));
if (context.getCommand() == MysqlCommand.COM_STMT_PREPARE) {
sendStmtPrepareOK();
sendStmtPrepareOK((int) context.getStmtId(), labels);
}
}
@ -2427,19 +2438,19 @@ public class StmtExecutor {
return exprs.stream().map(e -> PrimitiveType.STRING).collect(Collectors.toList());
}
private void sendStmtPrepareOK() throws IOException {
public void sendStmtPrepareOK(int stmtId, List<String> labels) throws IOException {
Preconditions.checkState(context.getConnectType() == ConnectType.MYSQL);
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_prepare.html#sect_protocol_com_stmt_prepare_response
serializer.reset();
// 0x00 OK
serializer.writeInt1(0);
// statement_id
serializer.writeInt4(Integer.valueOf(prepareStmt.getName()));
serializer.writeInt4(stmtId);
// num_columns
int numColumns = 0;
serializer.writeInt2(numColumns);
// num_params
int numParams = prepareStmt.getColLabelsOfPlaceHolders().size();
int numParams = labels.size();
serializer.writeInt2(numParams);
// reserved_1
serializer.writeInt1(0);
@ -2448,14 +2459,12 @@ public class StmtExecutor {
// send field one by one
// TODO use real type instead of string, for JDBC client it's ok
// but for other client, type should be correct
List<PrimitiveType> types = exprToStringType(prepareStmt.getPlaceHolderExprList());
List<String> colNames = prepareStmt.getColLabelsOfPlaceHolders();
if (LOG.isDebugEnabled()) {
LOG.debug("sendFields {}, {}", colNames, types);
}
// List<PrimitiveType> types = exprToStringType(labels);
List<String> colNames = labels;
for (int i = 0; i < colNames.size(); ++i) {
serializer.reset();
serializer.writeField(colNames.get(i), Type.fromPrimitiveType(types.get(i)));
// serializer.writeField(colNames.get(i), Type.fromPrimitiveType(types.get(i)));
serializer.writeField(colNames.get(i), Type.STRING);
context.getMysqlChannel().sendOnePacket(serializer.toByteBuffer());
}
serializer.reset();
@ -2484,14 +2493,15 @@ public class StmtExecutor {
// send field one by one
for (int i = 0; i < colNames.size(); ++i) {
serializer.reset();
if (prepareStmt != null && isExecuteStmt) {
if (prepareStmt != null && prepareStmt instanceof PrepareStmt
&& context.getCommand() == MysqlCommand.COM_STMT_EXECUTE) {
// Using PreparedStatment pre serializedField to avoid serialize each time
// we send a field
byte[] serializedField = prepareStmt.getSerializedField(colNames.get(i));
byte[] serializedField = ((PrepareStmt) prepareStmt).getSerializedField(colNames.get(i));
if (serializedField == null) {
serializer.writeField(colNames.get(i), types.get(i));
serializedField = serializer.toArray();
prepareStmt.setSerializedField(colNames.get(i), serializedField);
((PrepareStmt) prepareStmt).setSerializedField(colNames.get(i), serializedField);
}
context.getMysqlChannel().sendOnePacket(ByteBuffer.wrap(serializedField));
} else {