[feature](mysql) Support secure MySQL connection to FE (#17138)

Background:
Doris currently does not support SSL connection from MySQL clients, it's not secure enough in some cases, especially access Doris via the public internet.

Solution:
- Use TLS1.2 protocol to encrypt information.
- Implementation details
  * server <--- connect <--- client
  * if enable SSL: {
  * server <--- SSL connection request packet <--- client
  * server <--- SSL Exchange ---> client } (we will add this `if` logic part in this PR)
  * server ---> handshake request packet ---> client
  * server <--- encrypted data ---> client (this part will be realized in this PR)
- reference1 https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase.html#sect_protocol_connection_phase_initial_handshake_ssl_handshake
- reference2 https://www.rfc-editor.org/rfc/rfc5246

close #16313

Signed-off-by: Yukang Lian <yukang.lian2022@gmail.com>
Co-authored-by: Gavin Chou <gavineaglechou@gmail.com>
Co-authored-by: morningman <morningman@163.com>
This commit is contained in:
abmdocrt
2023-03-04 12:14:48 +08:00
committed by GitHub
parent 9f7386243f
commit 82df2ae9d8
22 changed files with 867 additions and 48 deletions

View File

@ -1188,7 +1188,10 @@ public enum ErrorCode {
+ "the length of table name '%s' is %d which is greater than the configuration 'table_name_length_limit' (%d)."),
ERR_NONSUPPORT_TIME_TRAVEL_TABLE(5090, new byte[]{'4', '2', '0', '0', '0'}, "Only iceberg external"
+ " table supports time travel in current version");
+ " table supports time travel in current version"),
ERR_NONSSL_HANDSHAKE_RESPONSE(5091, new byte[] {'4', '2', '0', '0'},
"SSL mode on but received non-ssl handshake response from client.");
// This is error code
private final int code;

View File

@ -50,7 +50,7 @@ public class DummyMysqlChannel extends MysqlChannel {
}
@Override
protected int readAll(ByteBuffer dstBuf) throws IOException {
protected int readAll(ByteBuffer dstBuf, boolean isHeader) throws IOException {
return 0;
}

View File

@ -75,7 +75,14 @@ public class MysqlCapability {
private static final int DEFAULT_FLAGS = Flag.CLIENT_PROTOCOL_41.getFlagBit()
| Flag.CLIENT_CONNECT_WITH_DB.getFlagBit() | Flag.CLIENT_SECURE_CONNECTION.getFlagBit()
| Flag.CLIENT_PLUGIN_AUTH.getFlagBit() | Flag.CLIENT_LOCAL_FILES.getFlagBit();
private static final int SSL_FLAGS = Flag.CLIENT_PROTOCOL_41.getFlagBit()
| Flag.CLIENT_CONNECT_WITH_DB.getFlagBit() | Flag.CLIENT_SECURE_CONNECTION.getFlagBit()
| Flag.CLIENT_PLUGIN_AUTH.getFlagBit() | Flag.CLIENT_LOCAL_FILES.getFlagBit()
| Flag.CLIENT_SSL.getFlagBit();
public static final MysqlCapability DEFAULT_CAPABILITY = new MysqlCapability(DEFAULT_FLAGS);
public static final MysqlCapability SSL_CAPABILITY = new MysqlCapability(SSL_FLAGS);
private int flags;
@ -112,6 +119,10 @@ public class MysqlCapability {
return (flags & Flag.CLIENT_PROTOCOL_41.getFlagBit()) != 0;
}
public boolean isClientUseSsl() {
return (flags & Flag.CLIENT_SSL.getFlagBit()) != 0;
}
public boolean isTransactions() {
return (flags & Flag.CLIENT_TRANSACTIONS.getFlagBit()) != 0;

View File

@ -23,12 +23,16 @@ import org.apache.doris.qe.ConnectProcessor;
import com.google.common.base.Preconditions;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.jetbrains.annotations.NotNull;
import org.xnio.StreamConnection;
import org.xnio.channels.Channels;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
/**
* This class used to read/write MySQL logical packet.
@ -43,6 +47,8 @@ public class MysqlChannel {
public static final int MAX_PHYSICAL_PACKET_LENGTH = 0xffffff;
// MySQL packet header length
protected static final int PACKET_HEADER_LEN = 4;
// SSL packet header length
protected static final int SSL_PACKET_HEADER_LEN = 5;
// next sequence id to receive or send
protected int sequenceId;
// channel connected with client
@ -50,13 +56,23 @@ public class MysqlChannel {
// used to receive/send header, avoiding new this many time.
protected ByteBuffer headerByteBuffer;
protected ByteBuffer defaultBuffer;
// default packet byte buffer for most packet
protected ByteBuffer sslHeaderByteBuffer;
protected ByteBuffer tempBuffer;
protected ByteBuffer remainingBuffer;
protected ByteBuffer sendBuffer;
protected ByteBuffer decryptAppData;
protected ByteBuffer encryptNetData;
// for log and show
protected String remoteHostPortString;
protected String remoteIp;
protected boolean isSend;
// Serializer used to pack MySQL packet.
protected boolean isSslMode;
protected boolean isSslHandshaking;
private SSLEngine sslEngine;
protected volatile MysqlSerializer serializer;
protected MysqlChannel() {
@ -86,6 +102,14 @@ public class MysqlChannel {
this.sendBuffer = ByteBuffer.allocate(2 * 1024 * 1024);
}
public void initSslBuffer() {
// allocate buffer when needed.
this.remainingBuffer = ByteBuffer.allocate(16 * 1024);
this.remainingBuffer.flip();
this.tempBuffer = ByteBuffer.allocate(16 * 1024);
this.sslHeaderByteBuffer = ByteBuffer.allocate(SSL_PACKET_HEADER_LEN);
}
public void setSequenceId(int sequenceId) {
this.sequenceId = sequenceId;
}
@ -94,14 +118,37 @@ public class MysqlChannel {
return remoteIp;
}
public void setSslEngine(SSLEngine sslEngine) {
this.sslEngine = sslEngine;
decryptAppData = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize() * 2);
encryptNetData = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize() * 2);
}
public void setSslMode(boolean sslMode) {
isSslMode = sslMode;
if (isSslMode) {
// channel in ssl mode means handshake phase has finished.
isSslHandshaking = false;
}
}
public void setSslHandshaking(boolean sslHandshaking) {
isSslHandshaking = sslHandshaking;
}
private int packetId() {
byte[] header = headerByteBuffer.array();
return header[3] & 0xFF;
}
private int packetLen() {
byte[] header = headerByteBuffer.array();
return (header[0] & 0xFF) | ((header[1] & 0XFF) << 8) | ((header[2] & 0XFF) << 16);
private int packetLen(boolean isSslHeader) {
if (isSslHeader) {
byte[] header = sslHeaderByteBuffer.array();
return (header[4] & 0xFF) | ((header[3] & 0XFF) << 8);
} else {
byte[] header = headerByteBuffer.array();
return (header[0] & 0xFF) | ((header[1] & 0XFF) << 8) | ((header[2] & 0XFF) << 16);
}
}
private void accSequenceId() {
@ -120,17 +167,30 @@ public class MysqlChannel {
}
}
protected int readAll(ByteBuffer dstBuf) throws IOException {
// all packet header is not encrypted, packet body is not sure.
protected int readAll(ByteBuffer dstBuf, boolean isHeader) throws IOException {
int readLen = 0;
if (!dstBuf.hasRemaining()) {
return 0;
}
if (remainingBuffer != null && remainingBuffer.hasRemaining()) {
int oldLen = dstBuf.position();
while (dstBuf.hasRemaining()) {
dstBuf.put(remainingBuffer.get());
}
return dstBuf.position() - oldLen;
}
try {
while (dstBuf.remaining() != 0) {
int ret = Channels.readBlocking(conn.getSourceChannel(), dstBuf);
// return -1 when remote peer close the channel
if (ret == -1) {
decryptData(dstBuf, isHeader);
return readLen;
}
readLen += ret;
}
decryptData(dstBuf, isHeader);
} catch (IOException e) {
LOG.debug("Read channel exception, ignore.", e);
return 0;
@ -138,51 +198,130 @@ public class MysqlChannel {
return readLen;
}
protected void decryptData(ByteBuffer dstBuf, boolean isHeader) throws SSLException {
// after decrypt, we get a mysql packet with mysql header.
if (!isSslMode || isHeader) {
return;
}
dstBuf.flip();
decryptAppData.clear();
// unwrap will remove ssl header.
while (true) {
SSLEngineResult result = sslEngine.unwrap(dstBuf, decryptAppData);
if (handleUnwrapResult(result) && !dstBuf.hasRemaining()) {
break;
}
// if BUFFER_OVERFLOW or BUFFER_UNDERFLOW, need to unwrap again, so we do nothing.
}
decryptAppData.flip();
dstBuf.clear();
dstBuf.put(decryptAppData);
dstBuf.flip();
}
// read one logical mysql protocol packet
// null for channel is closed.
// NOTE: all of the following code is assumed that the channel is in block mode.
// if in handshaking mode we return a packet with header otherwise without header.
public ByteBuffer fetchOnePacket() throws IOException {
int readLen;
ByteBuffer result = defaultBuffer;
result.clear();
while (true) {
headerByteBuffer.clear();
readLen = readAll(headerByteBuffer);
if (readLen != PACKET_HEADER_LEN) {
// remote has close this channel
LOG.debug("Receive packet header failed, remote may close the channel.");
return null;
}
if (packetId() != sequenceId) {
LOG.warn("receive packet sequence id[" + packetId() + "] want to get[" + sequenceId + "]");
throw new IOException("Bad packet sequence.");
}
int packetLen = packetLen();
if ((result.capacity() - result.position()) < packetLen) {
// byte buffer is not enough, new one packet
ByteBuffer tmp;
if (packetLen < MAX_PHYSICAL_PACKET_LENGTH) {
// last packet, enough to this packet is OK.
tmp = ByteBuffer.allocate(packetLen + result.position());
} else {
// already have packet, to allocate two packet.
tmp = ByteBuffer.allocate(2 * packetLen + result.position());
int packetLen;
// one SSL packet may include multiple Mysql packets, we use remainingBuffer to store them.
if ((isSslMode || isSslHandshaking) && !remainingBuffer.hasRemaining()) {
if (remainingBuffer.position() != 0) {
remainingBuffer.clear();
remainingBuffer.flip();
}
tmp.put(result.array(), 0, result.position());
result = tmp;
sslHeaderByteBuffer.clear();
readLen = readAll(sslHeaderByteBuffer, true);
if (readLen != SSL_PACKET_HEADER_LEN) {
// remote has close this channel
LOG.debug("Receive ssl packet header failed, remote may close the channel.");
return null;
}
// when handshaking and ssl mode, sslengine unwrap need a packet with header.
result.put(sslHeaderByteBuffer.array());
packetLen = packetLen(true);
} else {
headerByteBuffer.clear();
readLen = readAll(headerByteBuffer, true);
if (readLen != PACKET_HEADER_LEN) {
// remote has close this channel
LOG.debug("Receive packet header failed, remote may close the channel.");
return null;
}
if (packetId() != sequenceId) {
LOG.warn("receive packet sequence id[" + packetId() + "] want to get[" + sequenceId + "]");
throw new IOException("Bad packet sequence.");
}
packetLen = packetLen(false);
}
result = expandPacket(result, packetLen);
// read one physical packet
// before read, set limit to make read only one packet
result.limit(result.position() + packetLen);
readLen = readAll(result);
readLen = readAll(result, false);
if (isSslMode && remainingBuffer.position() == 0) {
byte[] header = result.array();
int packetId = header[3] & 0xFF;
if (packetId != sequenceId) {
LOG.warn("receive packet sequence id[" + packetId() + "] want to get[" + sequenceId + "]");
throw new IOException("Bad packet sequence.");
}
int mysqlPacketLength = (header[0] & 0xFF) | ((header[1] & 0XFF) << 8) | ((header[2] & 0XFF) << 16);
// remove mysql packet header
result.position(4);
result.compact();
// when encounter large sql query, one mysql packet will be packed as multiple ssl packets.
// we need to read all ssl packets to combine the complete mysql packet.
while (mysqlPacketLength > result.limit()) {
sslHeaderByteBuffer.clear();
readLen = readAll(sslHeaderByteBuffer, true);
if (readLen != SSL_PACKET_HEADER_LEN) {
// remote has close this channel
LOG.debug("Receive ssl packet header failed, remote may close the channel.");
return null;
}
tempBuffer.clear();
tempBuffer.put(sslHeaderByteBuffer.array());
packetLen = packetLen(true);
LOG.info("one ssl packet length is: " + packetLen);
tempBuffer = expandPacket(tempBuffer, packetLen);
result = expandPacket(result, tempBuffer.capacity());
// read one physical packet
// before read, set limit to make read only one packet
tempBuffer.limit(tempBuffer.position() + packetLen);
readLen = readAll(tempBuffer, false);
result.put(tempBuffer);
result.limit(result.position());
LOG.info("result is pos: " + result.position() + ", limit: "
+ result.limit() + "capacity: " + result.capacity());
}
if (mysqlPacketLength < result.position()) {
LOG.info("one SSL packet has multiple mysql packets.");
LOG.info("mysql packet length is " + mysqlPacketLength + ", result is pos: "
+ result.position() + ", limit: " + result.limit() + "capacity: " + result.capacity());
result.flip();
result.position(mysqlPacketLength);
remainingBuffer.clear();
remainingBuffer.put(result);
remainingBuffer.flip();
}
result.position(mysqlPacketLength);
}
if (readLen != packetLen) {
LOG.warn("Length of received packet content(" + readLen
+ ") is not equal with length in head.(" + packetLen + ")");
return null;
}
accSequenceId();
if (!isSslHandshaking) {
accSequenceId();
}
if (packetLen != MAX_PHYSICAL_PACKET_LENGTH) {
result.flip();
break;
@ -191,7 +330,27 @@ public class MysqlChannel {
return result;
}
@NotNull
private ByteBuffer expandPacket(ByteBuffer result, int packetLen) {
if ((result.capacity() - result.position()) < packetLen) {
// byte buffer is not enough, new one packet
ByteBuffer tmp;
if (packetLen < MAX_PHYSICAL_PACKET_LENGTH) {
// last packet, enough to this packet is OK.
tmp = ByteBuffer.allocate(packetLen + result.position());
} else {
// already have packet, to allocate two packet.
tmp = ByteBuffer.allocate(2 * packetLen + result.position());
}
tmp.put(result.array(), 0, result.position());
result = tmp;
}
result.limit(result.position() + packetLen);
return result;
}
protected void realNetSend(ByteBuffer buffer) throws IOException {
encryptData(buffer);
long bufLen = buffer.remaining();
long writeLen = Channels.writeBlocking(conn.getSinkChannel(), buffer);
if (bufLen != writeLen) {
@ -202,6 +361,23 @@ public class MysqlChannel {
isSend = true;
}
protected void encryptData(ByteBuffer dstBuf) throws SSLException {
if (!isSslMode) {
return;
}
encryptNetData.clear();
while (true) {
SSLEngineResult result = sslEngine.wrap(dstBuf, encryptNetData);
if (handleWrapResult(result) && !dstBuf.hasRemaining()) {
break;
}
}
encryptNetData.flip();
dstBuf.clear();
dstBuf.put(encryptNetData);
dstBuf.flip();
}
public void flush() throws IOException {
if (null == sendBuffer || sendBuffer.position() == 0) {
// Nothing to send
@ -213,7 +389,7 @@ public class MysqlChannel {
isSend = true;
}
private void writeHeader(int length) throws IOException {
private void writeHeader(int length, boolean isSsl) throws IOException {
if (null == sendBuffer) {
return;
}
@ -230,7 +406,7 @@ public class MysqlChannel {
sendBuffer.put((byte) sequenceId);
}
private void writeBuffer(ByteBuffer buffer) throws IOException {
private void writeBuffer(ByteBuffer buffer, boolean isSsl) throws IOException {
if (null == sendBuffer) {
return;
}
@ -250,19 +426,30 @@ public class MysqlChannel {
}
public void sendOnePacket(ByteBuffer packet) throws IOException {
// handshake in packet with header and has encrypted, need to send in ssl format
// ssl mode in packet no header and no encrypted, need to encrypted and add header and send in ssl format
int bufLen;
int oldLimit = packet.limit();
while (oldLimit - packet.position() >= MAX_PHYSICAL_PACKET_LENGTH) {
bufLen = MAX_PHYSICAL_PACKET_LENGTH;
packet.limit(packet.position() + bufLen);
writeHeader(bufLen);
writeBuffer(packet);
if (isSslHandshaking) {
writeBuffer(packet, true);
} else {
writeHeader(bufLen, isSslMode);
writeBuffer(packet, isSslMode);
accSequenceId();
}
}
if (isSslHandshaking) {
packet.limit(oldLimit);
writeBuffer(packet, true);
} else {
writeHeader(oldLimit - packet.position(), isSslMode);
packet.limit(oldLimit);
writeBuffer(packet, isSslMode);
accSequenceId();
}
writeHeader(oldLimit - packet.position());
packet.limit(oldLimit);
writeBuffer(packet);
accSequenceId();
}
public void sendAndFlush(ByteBuffer packet) throws IOException {
@ -306,4 +493,52 @@ public class MysqlChannel {
public MysqlSerializer getSerializer() {
return serializer;
}
private boolean handleWrapResult(SSLEngineResult sslEngineResult) throws SSLException {
switch (sslEngineResult.getStatus()) {
// normal status.
case OK:
return true;
case CLOSED:
sslEngine.closeOutbound();
return true;
case BUFFER_OVERFLOW:
// Could attempt to drain the serverNetData buffer of any already obtained
// data, but we'll just increase it to the size needed.
ByteBuffer newBuffer = ByteBuffer.allocate(encryptNetData.capacity() * 2);
encryptNetData.flip();
newBuffer.put(encryptNetData);
encryptNetData = newBuffer;
// retry the operation.
return false;
// when wrap BUFFER_UNDERFLOW and other status will not appear.
case BUFFER_UNDERFLOW:
default:
throw new IllegalStateException("invalid wrap status: " + sslEngineResult.getStatus());
}
}
private boolean handleUnwrapResult(SSLEngineResult sslEngineResult) {
switch (sslEngineResult.getStatus()) {
// normal status.
case OK:
return true;
case CLOSED:
sslEngine.closeOutbound();
return true;
case BUFFER_OVERFLOW:
// Could attempt to drain the clientAppData buffer of any already obtained
// data, but we'll just increase it to the size needed.
ByteBuffer newAppBuffer = ByteBuffer.allocate(decryptAppData.capacity() * 2);
decryptAppData.flip();
newAppBuffer.put(decryptAppData);
decryptAppData = newAppBuffer;
// retry the operation.
return false;
case BUFFER_UNDERFLOW:
default:
throw new IllegalStateException("invalid wrap status: " + sslEngineResult.getStatus());
}
}
}

View File

@ -30,6 +30,7 @@ public class MysqlHandshakePacket extends MysqlPacket {
private static final int CHARACTER_SET = 33;
// use default capability for all
private static final MysqlCapability CAPABILITY = MysqlCapability.DEFAULT_CAPABILITY;
private static final MysqlCapability SSL_CAPABILITY = MysqlCapability.SSL_CAPABILITY;
// status flags not supported in palo
private static final int STATUS_FLAGS = 0;
private static final String AUTH_PLUGIN_NAME = "mysql_native_password";
@ -49,7 +50,7 @@ public class MysqlHandshakePacket extends MysqlPacket {
@Override
public void writeTo(MysqlSerializer serializer) {
MysqlCapability capability = CAPABILITY;
MysqlCapability capability = MysqlProto.SERVER_USE_SSL ? SSL_CAPABILITY : CAPABILITY;
serializer.writeInt1(PROTOCOL_VERSION);
serializer.writeNulTerminateString(SERVER_VERSION);

View File

@ -45,6 +45,7 @@ import java.util.List;
// MySQL protocol util
public class MysqlProto {
private static final Logger LOG = LogManager.getLogger(MysqlProto.class);
public static final boolean SERVER_USE_SSL = Config.enable_ssl;
// scramble: data receive from server.
// randomString: data send by server in plug-in data field
@ -170,8 +171,68 @@ public class MysqlProto {
LOG.debug("Send and flush channel exception, ignore.", e);
return false;
}
// Server receive request packet from client, we need to determine which request type it is.
ByteBuffer clientRequestPacket = channel.fetchOnePacket();
MysqlCapability capability = new MysqlCapability(MysqlProto.readLowestInt4(clientRequestPacket));
// Server receive SSL connection request packet from client.
ByteBuffer sslConnectionRequest;
// Server receive authenticate packet from client.
ByteBuffer handshakeResponse = channel.fetchOnePacket();
ByteBuffer handshakeResponse;
if (capability.isClientUseSsl()) {
LOG.info("client is using ssl connection.");
// During development, we set SSL mode to true by default.
if (SERVER_USE_SSL) {
LOG.info("server is also using ssl connection. Will use ssl mode for data exchange.");
MysqlSslContext mysqlSslContext = context.getMysqlSslContext();
mysqlSslContext.init();
channel.initSslBuffer();
sslConnectionRequest = clientRequestPacket;
if (sslConnectionRequest == null) {
// receive response failed.
return false;
}
MysqlSslPacket sslPacket = new MysqlSslPacket();
if (!sslPacket.readFrom(sslConnectionRequest)) {
ErrorReport.report(ErrorCode.ERR_NOT_SUPPORTED_AUTH_MODE);
sendResponsePacket(context);
return false;
}
// try to establish ssl connection.
try {
// set channel to handshake mode to process data packet as ssl packet.
channel.setSslHandshaking(true);
// The ssl handshake phase still uses plaintext.
if (!mysqlSslContext.sslExchange(channel)) {
ErrorReport.report(ErrorCode.ERR_NOT_SUPPORTED_AUTH_MODE);
sendResponsePacket(context);
return false;
}
} catch (Exception e) {
throw new RuntimeException(e);
}
// if the exchange is successful, the channel will switch to ssl communication mode
// which means all data after this moment will be ciphertext.
// Set channel mode to ssl mode to handle socket packet in ssl format.
channel.setSslMode(true);
LOG.info("switch to ssl mode.");
handshakeResponse = channel.fetchOnePacket();
capability = new MysqlCapability(MysqlProto.readLowestInt4(handshakeResponse));
if (!capability.isClientUseSsl()) {
ErrorReport.report(ErrorCode.ERR_NONSSL_HANDSHAKE_RESPONSE);
sendResponsePacket(context);
return false;
}
} else {
handshakeResponse = clientRequestPacket;
}
} else {
handshakeResponse = clientRequestPacket;
}
if (handshakeResponse == null) {
// receive response failed.
return false;
@ -324,6 +385,10 @@ public class MysqlProto {
return buffer.get();
}
public static byte readByteAt(ByteBuffer buffer, int index) {
return buffer.get(index);
}
public static int readInt1(ByteBuffer buffer) {
return readByte(buffer) & 0XFF;
}
@ -337,6 +402,11 @@ public class MysqlProto {
buffer) & 0xFF) << 16);
}
public static int readLowestInt4(ByteBuffer buffer) {
return (readByteAt(buffer, 0) & 0xFF) | ((readByteAt(buffer, 1) & 0xFF) << 8) | ((readByteAt(
buffer, 2) & 0xFF) << 16) | ((readByteAt(buffer, 3) & 0XFF) << 24);
}
public static int readInt4(ByteBuffer buffer) {
return (readByte(buffer) & 0xFF) | ((readByte(buffer) & 0xFF) << 8) | ((readByte(
buffer) & 0xFF) << 16) | ((readByte(buffer) & 0XFF) << 24);

View File

@ -0,0 +1,277 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.mysql;
import org.apache.doris.common.Config;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.KeyManagementException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManagerFactory;
public class MysqlSslContext {
private static final Logger LOG = LogManager.getLogger(MysqlSslContext.class);
private SSLEngine sslEngine;
private SSLContext sslContext;
private String protocol;
private ByteBuffer serverAppData;
private static final String keyStoreFile = Config.mysql_ssl_default_certificate;
private static final String trustStoreFile = Config.mysql_ssl_default_certificate;
private static final String certificatePassword = Config.mysql_ssl_default_certificate_password;
private ByteBuffer serverNetData;
private ByteBuffer clientAppData;
private ByteBuffer clientNetData;
public MysqlSslContext(String protocol) {
this.protocol = protocol;
}
public void init() {
initSslContext();
initSslEngine();
}
private void initSslContext() {
try {
KeyStore ks = KeyStore.getInstance("PKCS12");
KeyStore ts = KeyStore.getInstance("PKCS12");
char[] password = certificatePassword.toCharArray();
ks.load(Files.newInputStream(Paths.get(keyStoreFile)), password);
ts.load(Files.newInputStream(Paths.get(trustStoreFile)), password);
KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
kmf.init(ks, password);
TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
tmf.init(ts);
sslContext = SSLContext.getInstance(protocol);
sslContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
} catch (NoSuchAlgorithmException | KeyManagementException | KeyStoreException | IOException
| CertificateException | UnrecoverableKeyException e) {
LOG.fatal("Failed to initialize SSL because", e);
}
}
private void initSslEngine() {
sslEngine = sslContext.createSSLEngine();
// set to server mode
sslEngine.setUseClientMode(false);
sslEngine.setEnabledCipherSuites(sslEngine.getSupportedCipherSuites());
}
public SSLEngine getSslEngine() {
return sslEngine;
}
public String getProtocol() {
return protocol;
}
/*
There may several steps for a successful handshake,
so it's typical to see the following series of operations:
client server message
====== ====== =======
wrap() ... ClientHello
... unwrap() ClientHello
... wrap() ServerHello/Certificate
unwrap() ... ServerHello/Certificate
wrap() ... ClientKeyExchange
wrap() ... ChangeCipherSpec
wrap() ... Finished
... unwrap() ClientKeyExchange
... unwrap() ChangeCipherSpec
... unwrap() Finished
... wrap() ChangeCipherSpec
... wrap() Finished
unwrap() ... ChangeCipherSpec
unwrap() ... Finished
reference: https://docs.oracle.com/javase/10/security/java-secure-socket-extension-jsse-reference-guide.htm#JSSEC-GUID-7FCC21CB-158B-440C-B5E4-E4E5A2D7352B
*/
public boolean sslExchange(MysqlChannel channel) throws Exception {
// long startTime = System.currentTimeMillis();
// init data buffer
initDataBuffer();
// set channel sslengine.
channel.setSslEngine(sslEngine);
// begin handshake.
sslEngine.beginHandshake();
while (sslEngine.getHandshakeStatus() != HandshakeStatus.FINISHED
&& sslEngine.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING) {
// if ((System.currentTimeMillis() - startTime) > 10000) {
// throw new Exception("try to establish SSL connection failed, timeout!");
// }
switch (sslEngine.getHandshakeStatus()) {
case NEED_WRAP:
handleNeedWrap(channel);
break;
case NEED_UNWRAP:
handleNeedUnwrap(channel);
break;
case NEED_TASK:
handleNeedTask();
break;
// Under normal circumstances, the following states will not appear
case NOT_HANDSHAKING:
throw new Exception("impossible HandshakeStatus: " + HandshakeStatus.NOT_HANDSHAKING);
case FINISHED:
throw new Exception("impossible HandshakeStatus: " + HandshakeStatus.FINISHED);
default:
throw new IllegalStateException("invalid HandshakeStatus: " + sslEngine.getHandshakeStatus());
}
}
return true;
}
private void initDataBuffer() {
int appLength = sslEngine.getSession().getApplicationBufferSize();
int netLength = sslEngine.getSession().getPacketBufferSize();
serverAppData = clientAppData = ByteBuffer.allocate(appLength);
serverNetData = clientNetData = ByteBuffer.allocate(netLength);
}
private void handleNeedTask() throws Exception {
Runnable runnable;
while ((runnable = sslEngine.getDelegatedTask()) != null) {
runnable.run();
}
HandshakeStatus hsStatus = sslEngine.getHandshakeStatus();
if (hsStatus == HandshakeStatus.NEED_TASK) {
throw new Exception("handshake shouldn't need additional tasks");
}
}
private void handleNeedWrap(MysqlChannel channel) {
try {
while (true) {
SSLEngineResult sslEngineResult = sslEngine.wrap(serverAppData, serverNetData);
if (handleWrapResult(sslEngineResult)) {
// if wrap normal, send packet.
// todo: refactor sendAndFlush.
serverNetData.flip();
channel.sendAndFlush(serverNetData);
serverNetData.clear();
break;
}
// if BUFFER_OVERFLOW, need to wrap again, so we do nothing.
}
} catch (SSLException e) {
sslEngine.closeOutbound();
} catch (IOException e) {
throw new RuntimeException("send failed");
}
}
private void handleNeedUnwrap(MysqlChannel channel) {
try {
// todo: refactor readAll.
clientNetData = channel.fetchOnePacket();
while (true) {
SSLEngineResult sslEngineResult = sslEngine.unwrap(clientNetData, clientAppData);
if (handleUnwrapResult(sslEngineResult)) {
clientAppData.clear();
break;
}
// if BUFFER_OVERFLOW or BUFFER_UNDERFLOW, need to unwrap again, so we do nothing.
}
} catch (IOException e) {
throw new RuntimeException("send failed");
}
}
private boolean handleWrapResult(SSLEngineResult sslEngineResult) throws SSLException {
switch (sslEngineResult.getStatus()) {
// normal status.
case OK:
return true;
case CLOSED:
sslEngine.closeOutbound();
return true;
case BUFFER_OVERFLOW:
// Could attempt to drain the serverNetData buffer of any already obtained
// data, but we'll just increase it to the size needed.
ByteBuffer newBuffer = ByteBuffer.allocate(serverNetData.capacity() * 2);
serverNetData.flip();
newBuffer.put(serverNetData);
serverNetData = newBuffer;
// retry the operation.
return false;
// when wrap BUFFER_UNDERFLOW and other status will not appear.
case BUFFER_UNDERFLOW:
default:
throw new IllegalStateException("invalid wrap status: " + sslEngineResult.getStatus());
}
}
private boolean handleUnwrapResult(SSLEngineResult sslEngineResult) {
switch (sslEngineResult.getStatus()) {
// normal status.
case OK:
return true;
case CLOSED:
sslEngine.closeOutbound();
return true;
case BUFFER_OVERFLOW:
// Could attempt to drain the clientAppData buffer of any already obtained
// data, but we'll just increase it to the size needed.
ByteBuffer newAppBuffer = ByteBuffer.allocate(clientAppData.capacity() * 2);
clientAppData.flip();
newAppBuffer.put(clientAppData);
clientAppData = newAppBuffer;
// retry the operation.
return false;
case BUFFER_UNDERFLOW:
int netSize = sslEngine.getSession().getPacketBufferSize();
// Resize buffer if needed.
if (netSize > clientAppData.capacity()) {
ByteBuffer newNetBuffer = ByteBuffer.allocateDirect(netSize);
clientNetData.flip();
newNetBuffer.put(clientNetData);
clientNetData = newNetBuffer;
}
// Obtain more inbound network data for clientNetData,
// then retry the operation.
return false;
default:
throw new IllegalStateException("invalid wrap status: " + sslEngineResult.getStatus());
}
}
}

View File

@ -0,0 +1,55 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.mysql;
import org.apache.doris.common.Config;
import java.nio.ByteBuffer;
public class MysqlSslPacket extends MysqlPacket {
private int maxPacketSize;
private int characterSet;
private byte[] randomString;
private MysqlCapability capability;
public boolean readFrom(ByteBuffer buffer) {
// read capability four byte, which CLIENT_PROTOCOL_41 must be set
capability = new MysqlCapability(MysqlProto.readInt4(buffer));
if (!capability.isProtocol41()) {
return false;
}
// max packet size
maxPacketSize = MysqlProto.readInt4(buffer);
// character set. only support 33(utf-8)
characterSet = MysqlProto.readInt1(buffer);
// reserved 23 bytes
if (new String(MysqlProto.readFixedString(buffer, 3)).equals(Config.proxy_auth_magic_prefix)) {
randomString = new byte[MysqlPassword.SCRAMBLE_LENGTH];
buffer.get(randomString);
} else {
buffer.position(buffer.position() + 20);
}
return true;
}
@Override
public void writeTo(MysqlSerializer serializer) {
}
}

View File

@ -32,6 +32,7 @@ import org.apache.doris.mysql.DummyMysqlChannel;
import org.apache.doris.mysql.MysqlCapability;
import org.apache.doris.mysql.MysqlChannel;
import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.mysql.MysqlSslContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.plugin.AuditEvent.AuditEventBuilder;
import org.apache.doris.resource.Tag;
@ -63,6 +64,8 @@ public class ConnectContext {
private static final Logger LOG = LogManager.getLogger(ConnectContext.class);
protected static ThreadLocal<ConnectContext> threadLocalInfo = new ThreadLocal<>();
private static final String SSL_PROTOCOL = "TLS";
// set this id before analyze
protected volatile long stmtId;
protected volatile long forwardedStmtId;
@ -149,6 +152,9 @@ public class ConnectContext {
private SessionContext sessionContext;
// This context is used for SSL connection between server and mysql client.
private final MysqlSslContext mysqlSslContext = new MysqlSslContext(SSL_PROTOCOL);
private long userQueryTimeout;
/**
@ -171,6 +177,10 @@ public class ConnectContext {
return sessionContext;
}
public MysqlSslContext getMysqlSslContext() {
return mysqlSslContext;
}
public void setOrUpdateInsertResult(long txnId, String label, String db, String tbl,
TransactionStatus txnStatus, long loadedRows, int filteredRows) {
if (isTxnModel() && insertResult != null) {

View File

@ -413,7 +413,9 @@ public class ConnectProcessor {
executor.execute();
if (i != stmts.size() - 1) {
ctx.getState().serverStatus |= MysqlServerStatusFlag.SERVER_MORE_RESULTS_EXISTS;
finalizeCommand();
if (ctx.getState().getStateType() != MysqlStateType.ERR) {
finalizeCommand();
}
}
auditAfterExec(auditStmt, executor.getParsedStmt(), executor.getQueryStatisticsForAuditLog());
// execute failed, skip remaining stmts

View File

@ -76,7 +76,8 @@ public class MysqlHandshakePacketTest {
Assert.assertEquals(0, MysqlProto.readInt2(buffer));
// capability flags
flags |= MysqlProto.readInt2(buffer) << 16;
Assert.assertEquals(MysqlCapability.DEFAULT_CAPABILITY.getFlags(), flags);
Assert.assertEquals(MysqlProto.SERVER_USE_SSL
? MysqlCapability.SSL_CAPABILITY.getFlags() : MysqlCapability.DEFAULT_CAPABILITY.getFlags(), flags);
// length of plugin data
Assert.assertEquals(21, MysqlProto.readInt1(buffer));
// length of plugin data