[Fix](mysql proto) avoid send duplicated OK packet (#21032)

1. The Mysql Go driver has a logic that terminates when it reads an EOF (end-of-file) and expects no data in the buffer. However, the front-end (FE) mistakenly returns an additional OK packet, which causes an exception to be thrown when reading the buffer.

2. Refactor some logic to support full prepared not just in where clause, like 
```
select ?, ? from tbl
```
This commit is contained in:
lihangyu
2023-06-21 12:00:22 +08:00
committed by GitHub
parent 18beb822a3
commit fcd778fb4f
8 changed files with 92 additions and 68 deletions

View File

@ -67,6 +67,7 @@ parser code {:
public boolean isVerbose = false;
public String wild;
public Expr where;
public ArrayList<PlaceHolderExpr> placeholder_expr_list = Lists.newArrayList();
// List of expected tokens ids from current parsing state for generating syntax error message
private final List<Integer> expectedTokenIds = Lists.newArrayList();
@ -1074,7 +1075,11 @@ stmt ::=
| switch_stmt:stmt
{: RESULT = stmt; :}
| query_stmt:query
{: RESULT = query; :}
{:
RESULT = query;
query.setPlaceHolders(parser.placeholder_expr_list);
parser.placeholder_expr_list.clear();
:}
| drop_stmt:stmt
{: RESULT = stmt; :}
| recover_stmt:stmt
@ -5185,6 +5190,8 @@ prepare_stmt ::=
KW_PREPARE variable_name:name KW_FROM select_stmt:s
{:
RESULT = new PrepareStmt(s, name, false);
s.setPlaceHolders(parser.placeholder_expr_list);
parser.placeholder_expr_list.clear();
:}
;
@ -6741,9 +6748,9 @@ literal ::=
| KW_NULL
{: RESULT = new NullLiteral(); :}
| PLACEHOLDER
{: RESULT = new PlaceHolderExpr(); :}
{: RESULT = new PlaceHolderExpr(); parser.placeholder_expr_list.add((PlaceHolderExpr) RESULT); :}
| MOD
{: RESULT = new PlaceHolderExpr(); :}
{: RESULT = new PlaceHolderExpr(); parser.placeholder_expr_list.add((PlaceHolderExpr) RESULT); :}
| UNMATCHED_STRING_LITERAL:l expr:e
{:
// we have an unmatched string literal.

View File

@ -17,7 +17,6 @@
package org.apache.doris.analysis;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.common.UserException;
import org.apache.doris.qe.ConnectContext;
@ -42,9 +41,6 @@ public class PrepareStmt extends StatementBase {
private static final Logger LOG = LogManager.getLogger(PrepareStmt.class);
private StatementBase inner;
private String stmtName;
// select * from tbl where a = ? and b = ?
// `?` is the placeholder
protected List<PlaceHolderExpr> placeholders = new ArrayList<>();
// Cached for better CPU performance, since serialize DescriptorTable and
// outputExprs are heavy work
@ -89,10 +85,6 @@ public class PrepareStmt extends StatementBase {
return id;
}
public List<PlaceHolderExpr> placeholders() {
return this.placeholders;
}
public boolean isBinaryProtocol() {
return binaryRowFormat;
}
@ -139,46 +131,6 @@ public class PrepareStmt extends StatementBase {
return serializedOutputExpr;
}
public int getParmCount() {
return placeholders.size();
}
public List<Expr> getSlotRefOfPlaceHolders() {
ArrayList<Expr> slots = new ArrayList<>();
if (inner instanceof SelectStmt) {
SelectStmt select = (SelectStmt) inner;
for (PlaceHolderExpr pexpr : placeholders) {
// Only point query support
for (Map.Entry<SlotRef, Expr> entry :
select.getPointQueryEQPredicates().entrySet()) {
// same instance
if (entry.getValue() == pexpr) {
slots.add(entry.getKey());
}
}
}
return slots;
}
return null;
}
public List<String> getColLabelsOfPlaceHolders() {
ArrayList<String> lables = new ArrayList<>();
if (inner instanceof SelectStmt) {
for (Expr slotExpr : getSlotRefOfPlaceHolders()) {
SlotRef slot = (SlotRef) slotExpr;
Column c = slot.getColumn();
if (c != null) {
lables.add(c.getName());
continue;
}
lables.add("");
}
return lables;
}
return null;
}
@Override
public void analyze(Analyzer analyzer) throws UserException {
if (!(inner instanceof SelectStmt)) {
@ -186,12 +138,7 @@ public class PrepareStmt extends StatementBase {
}
// Use tmpAnalyzer since selectStmt will be reAnalyzed
Analyzer tmpAnalyzer = new Analyzer(context.getEnv(), context);
// collect placeholders from stmt exprs tree
SelectStmt selectStmt = (SelectStmt) inner;
// TODO(lhy) support more clauses
if (selectStmt.getWhereClause() != null) {
selectStmt.getWhereClause().collect(PlaceHolderExpr.class, placeholders);
}
inner.analyze(tmpAnalyzer);
if (!selectStmt.checkAndSetPointQuery()) {
throw new UserException("Only support prepare SelectStmt point query now");
@ -217,17 +164,40 @@ public class PrepareStmt extends StatementBase {
return inner;
}
public int argsSize() {
return placeholders.size();
public List<PlaceHolderExpr> placeholders() {
return inner.getPlaceHolders();
}
public int getParmCount() {
return inner.getPlaceHolders().size();
}
public List<Expr> getPlaceHolderExprList() {
ArrayList<Expr> slots = new ArrayList<>();
for (PlaceHolderExpr pexpr : inner.getPlaceHolders()) {
slots.add(pexpr);
}
return slots;
}
public List<String> getColLabelsOfPlaceHolders() {
ArrayList<String> lables = new ArrayList<>();
for (int i = 0; i < inner.getPlaceHolders().size(); ++i) {
lables.add("lable " + i);
}
return lables;
}
public void asignValues(List<LiteralExpr> values) throws UserException {
if (values.size() != placeholders.size()) {
if (values.size() != inner.getPlaceHolders().size()) {
throw new UserException("Invalid arguments size "
+ values.size() + ", expected " + placeholders.size());
+ values.size() + ", expected " + inner.getPlaceHolders().size());
}
for (int i = 0; i < values.size(); ++i) {
placeholders.get(i).setLiteral(values.get(i));
inner.getPlaceHolders().get(i).setLiteral(values.get(i));
}
if (!values.isEmpty()) {
LOG.debug("assign values {}", values.get(0).toSql());
}
}
@ -237,7 +207,6 @@ public class PrepareStmt extends StatementBase {
serializedOutputExpr = null;
descTable = null;
this.id = UUID.randomUUID();
placeholders.clear();
inner.reset();
serializedFields.clear();
}

View File

@ -32,6 +32,7 @@ import org.apache.doris.thrift.TQueryOptions;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@ -57,6 +58,10 @@ public abstract class StatementBase implements ParseNode {
private boolean isPrepared = false;
// select * from tbl where a = ? and b = ?
// `?` is the placeholder
private ArrayList<PlaceHolderExpr> placeholders = new ArrayList<>();
protected StatementBase() { }
/**
@ -101,6 +106,14 @@ public abstract class StatementBase implements ParseNode {
return this.explainOptions != null;
}
public void setPlaceHolders(ArrayList<PlaceHolderExpr> placeholders) {
this.placeholders = new ArrayList<PlaceHolderExpr>(placeholders);
}
public ArrayList<PlaceHolderExpr> getPlaceHolders() {
return this.placeholders;
}
public boolean isVerbose() {
return explainOptions != null && explainOptions.isVerbose();
}

View File

@ -157,6 +157,10 @@ public class MysqlCapability {
return (flags & Flag.CLIENT_LOCAL_FILES.getFlagBit()) != 0;
}
public boolean isDeprecatedEOF() {
return (flags & Flag.CLIENT_DEPRECATE_EOF.getFlagBit()) != 0;
}
@Override
public boolean equals(Object obj) {
if (obj == null || !(obj instanceof MysqlCapability)) {

View File

@ -75,10 +75,21 @@ public class MysqlChannel {
protected volatile MysqlSerializer serializer;
// mysql flag CLIENT_DEPRECATE_EOF
private boolean clientDeprecatedEOF;
protected MysqlChannel() {
// For DummyMysqlChannel
}
public void setClientDeprecatedEOF() {
clientDeprecatedEOF = true;
}
public boolean clientDeprecatedEOF() {
return clientDeprecatedEOF;
}
public MysqlChannel(StreamConnection connection) {
Preconditions.checkNotNull(connection);
this.sequenceId = 0;

View File

@ -203,6 +203,9 @@ public class MysqlProto {
// receive response failed.
return false;
}
if (capability.isDeprecatedEOF()) {
context.getMysqlChannel().setClientDeprecatedEOF();
}
MysqlAuthPacket authPacket = new MysqlAuthPacket();
if (!authPacket.readFrom(handshakeResponse)) {
ErrorReport.report(ErrorCode.ERR_NOT_SUPPORTED_AUTH_MODE);

View File

@ -561,6 +561,7 @@ public class ConnectProcessor {
LOG.warn("Unknown command(" + code + ")");
return;
}
LOG.debug("handle command {}", command);
ctx.setCommand(command);
ctx.setStartTime();

View File

@ -1922,8 +1922,6 @@ public class StmtExecutor {
if (prepareStmt.isBinaryProtocol()) {
sendStmtPrepareOK();
}
// context.getState().setEof();
context.getState().setOk();
}
@ -1965,6 +1963,10 @@ public class StmtExecutor {
context.getMysqlChannel().sendOnePacket(serializer.toByteBuffer());
}
private List<PrimitiveType> exprToStringType(List<Expr> exprs) {
return exprs.stream().map(e -> PrimitiveType.STRING).collect(Collectors.toList());
}
private void sendStmtPrepareOK() throws IOException {
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_prepare.html#sect_protocol_com_stmt_prepare_response
serializer.reset();
@ -1979,13 +1981,27 @@ public class StmtExecutor {
int numParams = prepareStmt.getColLabelsOfPlaceHolders().size();
serializer.writeInt2(numParams);
// reserved_1
// serializer.writeInt1(0);
serializer.writeInt1(0);
context.getMysqlChannel().sendOnePacket(serializer.toByteBuffer());
if (numParams > 0) {
sendFields(prepareStmt.getColLabelsOfPlaceHolders(),
exprToType(prepareStmt.getSlotRefOfPlaceHolders()));
// send field one by one
// TODO use real type instead of string, for JDBC client it's ok
// but for other client, type should be correct
List<PrimitiveType> types = exprToStringType(prepareStmt.getPlaceHolderExprList());
List<String> colNames = prepareStmt.getColLabelsOfPlaceHolders();
LOG.debug("sendFields {}, {}", colNames, types);
for (int i = 0; i < colNames.size(); ++i) {
serializer.reset();
serializer.writeField(colNames.get(i), Type.fromPrimitiveType(types.get(i)));
context.getMysqlChannel().sendOnePacket(serializer.toByteBuffer());
}
}
// send EOF if nessessary
if (!context.getMysqlChannel().clientDeprecatedEOF()) {
context.getState().setEof();
} else {
context.getState().setOk();
}
context.getState().setOk();
}
private void sendFields(List<String> colNames, List<Type> types) throws IOException {