diff --git a/build.sh b/build.sh index 01d35d7880..db4b0cd71e 100755 --- a/build.sh +++ b/build.sh @@ -512,6 +512,7 @@ if [[ "${BUILD_FE}" -eq 1 ]]; then cp -r -p "${DORIS_HOME}/conf/fe.conf" "${DORIS_OUTPUT}/fe/conf"/ cp -r -p "${DORIS_HOME}/conf/ldap.conf" "${DORIS_OUTPUT}/fe/conf"/ cp -r -p "${DORIS_HOME}/conf"/*.xml "${DORIS_OUTPUT}/fe/conf"/ + cp -r -p "${DORIS_HOME}/conf/mysql_ssl_default_certificate" "${DORIS_OUTPUT}/fe/"/ rm -rf "${DORIS_OUTPUT}/fe/lib"/* cp -r -p "${DORIS_HOME}/fe/fe-core/target/lib"/* "${DORIS_OUTPUT}/fe/lib"/ rm -f "${DORIS_OUTPUT}/fe/lib/palo-fe.jar" diff --git a/conf/mysql_ssl_default_certificate/certificate.p12 b/conf/mysql_ssl_default_certificate/certificate.p12 new file mode 100644 index 0000000000..d54fde284b Binary files /dev/null and b/conf/mysql_ssl_default_certificate/certificate.p12 differ diff --git a/fe/fe-common/src/main/java/org/apache/doris/common/Config.java b/fe/fe-common/src/main/java/org/apache/doris/common/Config.java index 840f1237f0..9b6758e057 100644 --- a/fe/fe-common/src/main/java/org/apache/doris/common/Config.java +++ b/fe/fe-common/src/main/java/org/apache/doris/common/Config.java @@ -2004,6 +2004,25 @@ public class Config extends ConfigBase { @ConfField(masterOnly = true, mutable = true) public static int max_error_tablet_of_broker_load = 3; + /** + * If set to ture, doris will establish an encrypted channel based on the SSL protocol with mysql. + */ + @ConfField(mutable = false, masterOnly = false) + public static boolean enable_ssl = false; + + /** + * Default certificate file location for mysql ssl connection. + */ + @ConfField(mutable = false, masterOnly = false) + public static String mysql_ssl_default_certificate = System.getenv("DORIS_HOME") + + "/mysql_ssl_default_certificate/certificate.p12"; + + /** + * Password for default certificate file. + */ + @ConfField(mutable = false, masterOnly = false) + public static String mysql_ssl_default_certificate_password = "doris"; + /** * Used to set session variables randomly to check more issues in github workflow */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/common/ErrorCode.java b/fe/fe-core/src/main/java/org/apache/doris/common/ErrorCode.java index b65433388b..86ef7b9f26 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/common/ErrorCode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/common/ErrorCode.java @@ -1188,7 +1188,10 @@ public enum ErrorCode { + "the length of table name '%s' is %d which is greater than the configuration 'table_name_length_limit' (%d)."), ERR_NONSUPPORT_TIME_TRAVEL_TABLE(5090, new byte[]{'4', '2', '0', '0', '0'}, "Only iceberg external" - + " table supports time travel in current version"); + + " table supports time travel in current version"), + + ERR_NONSSL_HANDSHAKE_RESPONSE(5091, new byte[] {'4', '2', '0', '0'}, + "SSL mode on but received non-ssl handshake response from client."); // This is error code private final int code; diff --git a/fe/fe-core/src/main/java/org/apache/doris/mysql/DummyMysqlChannel.java b/fe/fe-core/src/main/java/org/apache/doris/mysql/DummyMysqlChannel.java index fad83a9d56..05b72552f9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/mysql/DummyMysqlChannel.java +++ b/fe/fe-core/src/main/java/org/apache/doris/mysql/DummyMysqlChannel.java @@ -50,7 +50,7 @@ public class DummyMysqlChannel extends MysqlChannel { } @Override - protected int readAll(ByteBuffer dstBuf) throws IOException { + protected int readAll(ByteBuffer dstBuf, boolean isHeader) throws IOException { return 0; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlCapability.java b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlCapability.java index 3984563ca9..52a56e197e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlCapability.java +++ b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlCapability.java @@ -75,7 +75,14 @@ public class MysqlCapability { private static final int DEFAULT_FLAGS = Flag.CLIENT_PROTOCOL_41.getFlagBit() | Flag.CLIENT_CONNECT_WITH_DB.getFlagBit() | Flag.CLIENT_SECURE_CONNECTION.getFlagBit() | Flag.CLIENT_PLUGIN_AUTH.getFlagBit() | Flag.CLIENT_LOCAL_FILES.getFlagBit(); + + private static final int SSL_FLAGS = Flag.CLIENT_PROTOCOL_41.getFlagBit() + | Flag.CLIENT_CONNECT_WITH_DB.getFlagBit() | Flag.CLIENT_SECURE_CONNECTION.getFlagBit() + | Flag.CLIENT_PLUGIN_AUTH.getFlagBit() | Flag.CLIENT_LOCAL_FILES.getFlagBit() + | Flag.CLIENT_SSL.getFlagBit(); + public static final MysqlCapability DEFAULT_CAPABILITY = new MysqlCapability(DEFAULT_FLAGS); + public static final MysqlCapability SSL_CAPABILITY = new MysqlCapability(SSL_FLAGS); private int flags; @@ -112,6 +119,10 @@ public class MysqlCapability { return (flags & Flag.CLIENT_PROTOCOL_41.getFlagBit()) != 0; } + public boolean isClientUseSsl() { + return (flags & Flag.CLIENT_SSL.getFlagBit()) != 0; + } + public boolean isTransactions() { return (flags & Flag.CLIENT_TRANSACTIONS.getFlagBit()) != 0; diff --git a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlChannel.java b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlChannel.java index 510a2672ad..9e048a6556 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlChannel.java +++ b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlChannel.java @@ -23,12 +23,16 @@ import org.apache.doris.qe.ConnectProcessor; import com.google.common.base.Preconditions; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.jetbrains.annotations.NotNull; import org.xnio.StreamConnection; import org.xnio.channels.Channels; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLException; /** * This class used to read/write MySQL logical packet. @@ -43,6 +47,8 @@ public class MysqlChannel { public static final int MAX_PHYSICAL_PACKET_LENGTH = 0xffffff; // MySQL packet header length protected static final int PACKET_HEADER_LEN = 4; + // SSL packet header length + protected static final int SSL_PACKET_HEADER_LEN = 5; // next sequence id to receive or send protected int sequenceId; // channel connected with client @@ -50,13 +56,23 @@ public class MysqlChannel { // used to receive/send header, avoiding new this many time. protected ByteBuffer headerByteBuffer; protected ByteBuffer defaultBuffer; - // default packet byte buffer for most packet + protected ByteBuffer sslHeaderByteBuffer; + protected ByteBuffer tempBuffer; + protected ByteBuffer remainingBuffer; protected ByteBuffer sendBuffer; + + protected ByteBuffer decryptAppData; + protected ByteBuffer encryptNetData; + // for log and show protected String remoteHostPortString; protected String remoteIp; protected boolean isSend; - // Serializer used to pack MySQL packet. + + protected boolean isSslMode; + protected boolean isSslHandshaking; + private SSLEngine sslEngine; + protected volatile MysqlSerializer serializer; protected MysqlChannel() { @@ -86,6 +102,14 @@ public class MysqlChannel { this.sendBuffer = ByteBuffer.allocate(2 * 1024 * 1024); } + public void initSslBuffer() { + // allocate buffer when needed. + this.remainingBuffer = ByteBuffer.allocate(16 * 1024); + this.remainingBuffer.flip(); + this.tempBuffer = ByteBuffer.allocate(16 * 1024); + this.sslHeaderByteBuffer = ByteBuffer.allocate(SSL_PACKET_HEADER_LEN); + } + public void setSequenceId(int sequenceId) { this.sequenceId = sequenceId; } @@ -94,14 +118,37 @@ public class MysqlChannel { return remoteIp; } + public void setSslEngine(SSLEngine sslEngine) { + this.sslEngine = sslEngine; + decryptAppData = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize() * 2); + encryptNetData = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize() * 2); + } + + public void setSslMode(boolean sslMode) { + isSslMode = sslMode; + if (isSslMode) { + // channel in ssl mode means handshake phase has finished. + isSslHandshaking = false; + } + } + + public void setSslHandshaking(boolean sslHandshaking) { + isSslHandshaking = sslHandshaking; + } + private int packetId() { byte[] header = headerByteBuffer.array(); return header[3] & 0xFF; } - private int packetLen() { - byte[] header = headerByteBuffer.array(); - return (header[0] & 0xFF) | ((header[1] & 0XFF) << 8) | ((header[2] & 0XFF) << 16); + private int packetLen(boolean isSslHeader) { + if (isSslHeader) { + byte[] header = sslHeaderByteBuffer.array(); + return (header[4] & 0xFF) | ((header[3] & 0XFF) << 8); + } else { + byte[] header = headerByteBuffer.array(); + return (header[0] & 0xFF) | ((header[1] & 0XFF) << 8) | ((header[2] & 0XFF) << 16); + } } private void accSequenceId() { @@ -120,17 +167,30 @@ public class MysqlChannel { } } - protected int readAll(ByteBuffer dstBuf) throws IOException { + // all packet header is not encrypted, packet body is not sure. + protected int readAll(ByteBuffer dstBuf, boolean isHeader) throws IOException { int readLen = 0; + if (!dstBuf.hasRemaining()) { + return 0; + } + if (remainingBuffer != null && remainingBuffer.hasRemaining()) { + int oldLen = dstBuf.position(); + while (dstBuf.hasRemaining()) { + dstBuf.put(remainingBuffer.get()); + } + return dstBuf.position() - oldLen; + } try { while (dstBuf.remaining() != 0) { int ret = Channels.readBlocking(conn.getSourceChannel(), dstBuf); // return -1 when remote peer close the channel if (ret == -1) { + decryptData(dstBuf, isHeader); return readLen; } readLen += ret; } + decryptData(dstBuf, isHeader); } catch (IOException e) { LOG.debug("Read channel exception, ignore.", e); return 0; @@ -138,51 +198,130 @@ public class MysqlChannel { return readLen; } + protected void decryptData(ByteBuffer dstBuf, boolean isHeader) throws SSLException { + // after decrypt, we get a mysql packet with mysql header. + if (!isSslMode || isHeader) { + return; + } + dstBuf.flip(); + decryptAppData.clear(); + // unwrap will remove ssl header. + while (true) { + SSLEngineResult result = sslEngine.unwrap(dstBuf, decryptAppData); + if (handleUnwrapResult(result) && !dstBuf.hasRemaining()) { + break; + } + // if BUFFER_OVERFLOW or BUFFER_UNDERFLOW, need to unwrap again, so we do nothing. + } + decryptAppData.flip(); + dstBuf.clear(); + dstBuf.put(decryptAppData); + dstBuf.flip(); + } + // read one logical mysql protocol packet // null for channel is closed. // NOTE: all of the following code is assumed that the channel is in block mode. + // if in handshaking mode we return a packet with header otherwise without header. public ByteBuffer fetchOnePacket() throws IOException { int readLen; ByteBuffer result = defaultBuffer; result.clear(); while (true) { - headerByteBuffer.clear(); - readLen = readAll(headerByteBuffer); - if (readLen != PACKET_HEADER_LEN) { - // remote has close this channel - LOG.debug("Receive packet header failed, remote may close the channel."); - return null; - } - if (packetId() != sequenceId) { - LOG.warn("receive packet sequence id[" + packetId() + "] want to get[" + sequenceId + "]"); - throw new IOException("Bad packet sequence."); - } - int packetLen = packetLen(); - if ((result.capacity() - result.position()) < packetLen) { - // byte buffer is not enough, new one packet - ByteBuffer tmp; - if (packetLen < MAX_PHYSICAL_PACKET_LENGTH) { - // last packet, enough to this packet is OK. - tmp = ByteBuffer.allocate(packetLen + result.position()); - } else { - // already have packet, to allocate two packet. - tmp = ByteBuffer.allocate(2 * packetLen + result.position()); + int packetLen; + // one SSL packet may include multiple Mysql packets, we use remainingBuffer to store them. + if ((isSslMode || isSslHandshaking) && !remainingBuffer.hasRemaining()) { + if (remainingBuffer.position() != 0) { + remainingBuffer.clear(); + remainingBuffer.flip(); } - tmp.put(result.array(), 0, result.position()); - result = tmp; + sslHeaderByteBuffer.clear(); + readLen = readAll(sslHeaderByteBuffer, true); + if (readLen != SSL_PACKET_HEADER_LEN) { + // remote has close this channel + LOG.debug("Receive ssl packet header failed, remote may close the channel."); + return null; + } + // when handshaking and ssl mode, sslengine unwrap need a packet with header. + result.put(sslHeaderByteBuffer.array()); + packetLen = packetLen(true); + } else { + headerByteBuffer.clear(); + readLen = readAll(headerByteBuffer, true); + if (readLen != PACKET_HEADER_LEN) { + // remote has close this channel + LOG.debug("Receive packet header failed, remote may close the channel."); + return null; + } + if (packetId() != sequenceId) { + LOG.warn("receive packet sequence id[" + packetId() + "] want to get[" + sequenceId + "]"); + throw new IOException("Bad packet sequence."); + } + packetLen = packetLen(false); } + result = expandPacket(result, packetLen); // read one physical packet // before read, set limit to make read only one packet result.limit(result.position() + packetLen); - readLen = readAll(result); + readLen = readAll(result, false); + if (isSslMode && remainingBuffer.position() == 0) { + byte[] header = result.array(); + int packetId = header[3] & 0xFF; + if (packetId != sequenceId) { + LOG.warn("receive packet sequence id[" + packetId() + "] want to get[" + sequenceId + "]"); + throw new IOException("Bad packet sequence."); + } + int mysqlPacketLength = (header[0] & 0xFF) | ((header[1] & 0XFF) << 8) | ((header[2] & 0XFF) << 16); + // remove mysql packet header + result.position(4); + result.compact(); + // when encounter large sql query, one mysql packet will be packed as multiple ssl packets. + // we need to read all ssl packets to combine the complete mysql packet. + while (mysqlPacketLength > result.limit()) { + sslHeaderByteBuffer.clear(); + readLen = readAll(sslHeaderByteBuffer, true); + if (readLen != SSL_PACKET_HEADER_LEN) { + // remote has close this channel + LOG.debug("Receive ssl packet header failed, remote may close the channel."); + return null; + } + tempBuffer.clear(); + tempBuffer.put(sslHeaderByteBuffer.array()); + packetLen = packetLen(true); + LOG.info("one ssl packet length is: " + packetLen); + tempBuffer = expandPacket(tempBuffer, packetLen); + result = expandPacket(result, tempBuffer.capacity()); + // read one physical packet + // before read, set limit to make read only one packet + tempBuffer.limit(tempBuffer.position() + packetLen); + readLen = readAll(tempBuffer, false); + result.put(tempBuffer); + result.limit(result.position()); + LOG.info("result is pos: " + result.position() + ", limit: " + + result.limit() + "capacity: " + result.capacity()); + } + if (mysqlPacketLength < result.position()) { + LOG.info("one SSL packet has multiple mysql packets."); + LOG.info("mysql packet length is " + mysqlPacketLength + ", result is pos: " + + result.position() + ", limit: " + result.limit() + "capacity: " + result.capacity()); + result.flip(); + result.position(mysqlPacketLength); + remainingBuffer.clear(); + remainingBuffer.put(result); + remainingBuffer.flip(); + } + result.position(mysqlPacketLength); + } if (readLen != packetLen) { LOG.warn("Length of received packet content(" + readLen + ") is not equal with length in head.(" + packetLen + ")"); return null; } - accSequenceId(); + if (!isSslHandshaking) { + accSequenceId(); + } if (packetLen != MAX_PHYSICAL_PACKET_LENGTH) { result.flip(); break; @@ -191,7 +330,27 @@ public class MysqlChannel { return result; } + @NotNull + private ByteBuffer expandPacket(ByteBuffer result, int packetLen) { + if ((result.capacity() - result.position()) < packetLen) { + // byte buffer is not enough, new one packet + ByteBuffer tmp; + if (packetLen < MAX_PHYSICAL_PACKET_LENGTH) { + // last packet, enough to this packet is OK. + tmp = ByteBuffer.allocate(packetLen + result.position()); + } else { + // already have packet, to allocate two packet. + tmp = ByteBuffer.allocate(2 * packetLen + result.position()); + } + tmp.put(result.array(), 0, result.position()); + result = tmp; + } + result.limit(result.position() + packetLen); + return result; + } + protected void realNetSend(ByteBuffer buffer) throws IOException { + encryptData(buffer); long bufLen = buffer.remaining(); long writeLen = Channels.writeBlocking(conn.getSinkChannel(), buffer); if (bufLen != writeLen) { @@ -202,6 +361,23 @@ public class MysqlChannel { isSend = true; } + protected void encryptData(ByteBuffer dstBuf) throws SSLException { + if (!isSslMode) { + return; + } + encryptNetData.clear(); + while (true) { + SSLEngineResult result = sslEngine.wrap(dstBuf, encryptNetData); + if (handleWrapResult(result) && !dstBuf.hasRemaining()) { + break; + } + } + encryptNetData.flip(); + dstBuf.clear(); + dstBuf.put(encryptNetData); + dstBuf.flip(); + } + public void flush() throws IOException { if (null == sendBuffer || sendBuffer.position() == 0) { // Nothing to send @@ -213,7 +389,7 @@ public class MysqlChannel { isSend = true; } - private void writeHeader(int length) throws IOException { + private void writeHeader(int length, boolean isSsl) throws IOException { if (null == sendBuffer) { return; } @@ -230,7 +406,7 @@ public class MysqlChannel { sendBuffer.put((byte) sequenceId); } - private void writeBuffer(ByteBuffer buffer) throws IOException { + private void writeBuffer(ByteBuffer buffer, boolean isSsl) throws IOException { if (null == sendBuffer) { return; } @@ -250,19 +426,30 @@ public class MysqlChannel { } public void sendOnePacket(ByteBuffer packet) throws IOException { + // handshake in packet with header and has encrypted, need to send in ssl format + // ssl mode in packet no header and no encrypted, need to encrypted and add header and send in ssl format int bufLen; int oldLimit = packet.limit(); while (oldLimit - packet.position() >= MAX_PHYSICAL_PACKET_LENGTH) { bufLen = MAX_PHYSICAL_PACKET_LENGTH; packet.limit(packet.position() + bufLen); - writeHeader(bufLen); - writeBuffer(packet); + if (isSslHandshaking) { + writeBuffer(packet, true); + } else { + writeHeader(bufLen, isSslMode); + writeBuffer(packet, isSslMode); + accSequenceId(); + } + } + if (isSslHandshaking) { + packet.limit(oldLimit); + writeBuffer(packet, true); + } else { + writeHeader(oldLimit - packet.position(), isSslMode); + packet.limit(oldLimit); + writeBuffer(packet, isSslMode); accSequenceId(); } - writeHeader(oldLimit - packet.position()); - packet.limit(oldLimit); - writeBuffer(packet); - accSequenceId(); } public void sendAndFlush(ByteBuffer packet) throws IOException { @@ -306,4 +493,52 @@ public class MysqlChannel { public MysqlSerializer getSerializer() { return serializer; } + + private boolean handleWrapResult(SSLEngineResult sslEngineResult) throws SSLException { + switch (sslEngineResult.getStatus()) { + // normal status. + case OK: + return true; + case CLOSED: + sslEngine.closeOutbound(); + return true; + case BUFFER_OVERFLOW: + // Could attempt to drain the serverNetData buffer of any already obtained + // data, but we'll just increase it to the size needed. + ByteBuffer newBuffer = ByteBuffer.allocate(encryptNetData.capacity() * 2); + encryptNetData.flip(); + newBuffer.put(encryptNetData); + encryptNetData = newBuffer; + // retry the operation. + return false; + // when wrap BUFFER_UNDERFLOW and other status will not appear. + case BUFFER_UNDERFLOW: + default: + throw new IllegalStateException("invalid wrap status: " + sslEngineResult.getStatus()); + } + } + + private boolean handleUnwrapResult(SSLEngineResult sslEngineResult) { + switch (sslEngineResult.getStatus()) { + // normal status. + case OK: + return true; + case CLOSED: + sslEngine.closeOutbound(); + return true; + case BUFFER_OVERFLOW: + // Could attempt to drain the clientAppData buffer of any already obtained + // data, but we'll just increase it to the size needed. + ByteBuffer newAppBuffer = ByteBuffer.allocate(decryptAppData.capacity() * 2); + decryptAppData.flip(); + newAppBuffer.put(decryptAppData); + decryptAppData = newAppBuffer; + // retry the operation. + return false; + case BUFFER_UNDERFLOW: + default: + throw new IllegalStateException("invalid wrap status: " + sslEngineResult.getStatus()); + } + } + } diff --git a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlHandshakePacket.java b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlHandshakePacket.java index 209f9b81cf..c2ba21a23e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlHandshakePacket.java +++ b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlHandshakePacket.java @@ -30,6 +30,7 @@ public class MysqlHandshakePacket extends MysqlPacket { private static final int CHARACTER_SET = 33; // use default capability for all private static final MysqlCapability CAPABILITY = MysqlCapability.DEFAULT_CAPABILITY; + private static final MysqlCapability SSL_CAPABILITY = MysqlCapability.SSL_CAPABILITY; // status flags not supported in palo private static final int STATUS_FLAGS = 0; private static final String AUTH_PLUGIN_NAME = "mysql_native_password"; @@ -49,7 +50,7 @@ public class MysqlHandshakePacket extends MysqlPacket { @Override public void writeTo(MysqlSerializer serializer) { - MysqlCapability capability = CAPABILITY; + MysqlCapability capability = MysqlProto.SERVER_USE_SSL ? SSL_CAPABILITY : CAPABILITY; serializer.writeInt1(PROTOCOL_VERSION); serializer.writeNulTerminateString(SERVER_VERSION); diff --git a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlProto.java b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlProto.java index d5635b1afd..d51dee1924 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlProto.java +++ b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlProto.java @@ -45,6 +45,7 @@ import java.util.List; // MySQL protocol util public class MysqlProto { private static final Logger LOG = LogManager.getLogger(MysqlProto.class); + public static final boolean SERVER_USE_SSL = Config.enable_ssl; // scramble: data receive from server. // randomString: data send by server in plug-in data field @@ -170,8 +171,68 @@ public class MysqlProto { LOG.debug("Send and flush channel exception, ignore.", e); return false; } + + // Server receive request packet from client, we need to determine which request type it is. + ByteBuffer clientRequestPacket = channel.fetchOnePacket(); + MysqlCapability capability = new MysqlCapability(MysqlProto.readLowestInt4(clientRequestPacket)); + + // Server receive SSL connection request packet from client. + ByteBuffer sslConnectionRequest; // Server receive authenticate packet from client. - ByteBuffer handshakeResponse = channel.fetchOnePacket(); + ByteBuffer handshakeResponse; + + if (capability.isClientUseSsl()) { + LOG.info("client is using ssl connection."); + // During development, we set SSL mode to true by default. + if (SERVER_USE_SSL) { + LOG.info("server is also using ssl connection. Will use ssl mode for data exchange."); + MysqlSslContext mysqlSslContext = context.getMysqlSslContext(); + mysqlSslContext.init(); + channel.initSslBuffer(); + sslConnectionRequest = clientRequestPacket; + if (sslConnectionRequest == null) { + // receive response failed. + return false; + } + MysqlSslPacket sslPacket = new MysqlSslPacket(); + if (!sslPacket.readFrom(sslConnectionRequest)) { + ErrorReport.report(ErrorCode.ERR_NOT_SUPPORTED_AUTH_MODE); + sendResponsePacket(context); + return false; + } + // try to establish ssl connection. + try { + // set channel to handshake mode to process data packet as ssl packet. + channel.setSslHandshaking(true); + // The ssl handshake phase still uses plaintext. + if (!mysqlSslContext.sslExchange(channel)) { + ErrorReport.report(ErrorCode.ERR_NOT_SUPPORTED_AUTH_MODE); + sendResponsePacket(context); + return false; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + // if the exchange is successful, the channel will switch to ssl communication mode + // which means all data after this moment will be ciphertext. + + // Set channel mode to ssl mode to handle socket packet in ssl format. + channel.setSslMode(true); + LOG.info("switch to ssl mode."); + handshakeResponse = channel.fetchOnePacket(); + capability = new MysqlCapability(MysqlProto.readLowestInt4(handshakeResponse)); + if (!capability.isClientUseSsl()) { + ErrorReport.report(ErrorCode.ERR_NONSSL_HANDSHAKE_RESPONSE); + sendResponsePacket(context); + return false; + } + } else { + handshakeResponse = clientRequestPacket; + } + } else { + handshakeResponse = clientRequestPacket; + } + if (handshakeResponse == null) { // receive response failed. return false; @@ -324,6 +385,10 @@ public class MysqlProto { return buffer.get(); } + public static byte readByteAt(ByteBuffer buffer, int index) { + return buffer.get(index); + } + public static int readInt1(ByteBuffer buffer) { return readByte(buffer) & 0XFF; } @@ -337,6 +402,11 @@ public class MysqlProto { buffer) & 0xFF) << 16); } + public static int readLowestInt4(ByteBuffer buffer) { + return (readByteAt(buffer, 0) & 0xFF) | ((readByteAt(buffer, 1) & 0xFF) << 8) | ((readByteAt( + buffer, 2) & 0xFF) << 16) | ((readByteAt(buffer, 3) & 0XFF) << 24); + } + public static int readInt4(ByteBuffer buffer) { return (readByte(buffer) & 0xFF) | ((readByte(buffer) & 0xFF) << 8) | ((readByte( buffer) & 0xFF) << 16) | ((readByte(buffer) & 0XFF) << 24); diff --git a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlSslContext.java b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlSslContext.java new file mode 100644 index 0000000000..3aa7dd45a7 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlSslContext.java @@ -0,0 +1,277 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.mysql; + +import org.apache.doris.common.Config; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.security.KeyManagementException; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.UnrecoverableKeyException; +import java.security.cert.CertificateException; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManagerFactory; + +public class MysqlSslContext { + + private static final Logger LOG = LogManager.getLogger(MysqlSslContext.class); + private SSLEngine sslEngine; + private SSLContext sslContext; + private String protocol; + private ByteBuffer serverAppData; + private static final String keyStoreFile = Config.mysql_ssl_default_certificate; + private static final String trustStoreFile = Config.mysql_ssl_default_certificate; + private static final String certificatePassword = Config.mysql_ssl_default_certificate_password; + private ByteBuffer serverNetData; + private ByteBuffer clientAppData; + private ByteBuffer clientNetData; + + public MysqlSslContext(String protocol) { + this.protocol = protocol; + } + + public void init() { + initSslContext(); + initSslEngine(); + } + + private void initSslContext() { + try { + KeyStore ks = KeyStore.getInstance("PKCS12"); + KeyStore ts = KeyStore.getInstance("PKCS12"); + + char[] password = certificatePassword.toCharArray(); + + ks.load(Files.newInputStream(Paths.get(keyStoreFile)), password); + ts.load(Files.newInputStream(Paths.get(trustStoreFile)), password); + + KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + kmf.init(ks, password); + + TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(ts); + sslContext = SSLContext.getInstance(protocol); + sslContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + } catch (NoSuchAlgorithmException | KeyManagementException | KeyStoreException | IOException + | CertificateException | UnrecoverableKeyException e) { + LOG.fatal("Failed to initialize SSL because", e); + } + } + + private void initSslEngine() { + sslEngine = sslContext.createSSLEngine(); + // set to server mode + sslEngine.setUseClientMode(false); + sslEngine.setEnabledCipherSuites(sslEngine.getSupportedCipherSuites()); + } + + public SSLEngine getSslEngine() { + return sslEngine; + } + + public String getProtocol() { + return protocol; + } + + + /* + There may several steps for a successful handshake, + so it's typical to see the following series of operations: + + client server message + ====== ====== ======= + wrap() ... ClientHello + ... unwrap() ClientHello + ... wrap() ServerHello/Certificate + unwrap() ... ServerHello/Certificate + wrap() ... ClientKeyExchange + wrap() ... ChangeCipherSpec + wrap() ... Finished + ... unwrap() ClientKeyExchange + ... unwrap() ChangeCipherSpec + ... unwrap() Finished + ... wrap() ChangeCipherSpec + ... wrap() Finished + unwrap() ... ChangeCipherSpec + unwrap() ... Finished + reference: https://docs.oracle.com/javase/10/security/java-secure-socket-extension-jsse-reference-guide.htm#JSSEC-GUID-7FCC21CB-158B-440C-B5E4-E4E5A2D7352B + */ + public boolean sslExchange(MysqlChannel channel) throws Exception { + // long startTime = System.currentTimeMillis(); + // init data buffer + initDataBuffer(); + // set channel sslengine. + channel.setSslEngine(sslEngine); + // begin handshake. + sslEngine.beginHandshake(); + while (sslEngine.getHandshakeStatus() != HandshakeStatus.FINISHED + && sslEngine.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING) { + // if ((System.currentTimeMillis() - startTime) > 10000) { + // throw new Exception("try to establish SSL connection failed, timeout!"); + // } + switch (sslEngine.getHandshakeStatus()) { + case NEED_WRAP: + handleNeedWrap(channel); + break; + case NEED_UNWRAP: + handleNeedUnwrap(channel); + break; + case NEED_TASK: + handleNeedTask(); + break; + // Under normal circumstances, the following states will not appear + case NOT_HANDSHAKING: + throw new Exception("impossible HandshakeStatus: " + HandshakeStatus.NOT_HANDSHAKING); + case FINISHED: + throw new Exception("impossible HandshakeStatus: " + HandshakeStatus.FINISHED); + default: + throw new IllegalStateException("invalid HandshakeStatus: " + sslEngine.getHandshakeStatus()); + } + } + return true; + } + + private void initDataBuffer() { + int appLength = sslEngine.getSession().getApplicationBufferSize(); + int netLength = sslEngine.getSession().getPacketBufferSize(); + serverAppData = clientAppData = ByteBuffer.allocate(appLength); + serverNetData = clientNetData = ByteBuffer.allocate(netLength); + } + + private void handleNeedTask() throws Exception { + Runnable runnable; + while ((runnable = sslEngine.getDelegatedTask()) != null) { + runnable.run(); + } + HandshakeStatus hsStatus = sslEngine.getHandshakeStatus(); + if (hsStatus == HandshakeStatus.NEED_TASK) { + throw new Exception("handshake shouldn't need additional tasks"); + } + } + + private void handleNeedWrap(MysqlChannel channel) { + try { + while (true) { + SSLEngineResult sslEngineResult = sslEngine.wrap(serverAppData, serverNetData); + if (handleWrapResult(sslEngineResult)) { + // if wrap normal, send packet. + // todo: refactor sendAndFlush. + serverNetData.flip(); + channel.sendAndFlush(serverNetData); + serverNetData.clear(); + break; + } + // if BUFFER_OVERFLOW, need to wrap again, so we do nothing. + } + } catch (SSLException e) { + sslEngine.closeOutbound(); + } catch (IOException e) { + throw new RuntimeException("send failed"); + } + } + + private void handleNeedUnwrap(MysqlChannel channel) { + try { + // todo: refactor readAll. + clientNetData = channel.fetchOnePacket(); + while (true) { + SSLEngineResult sslEngineResult = sslEngine.unwrap(clientNetData, clientAppData); + if (handleUnwrapResult(sslEngineResult)) { + clientAppData.clear(); + break; + } + // if BUFFER_OVERFLOW or BUFFER_UNDERFLOW, need to unwrap again, so we do nothing. + } + } catch (IOException e) { + throw new RuntimeException("send failed"); + } + } + + + private boolean handleWrapResult(SSLEngineResult sslEngineResult) throws SSLException { + switch (sslEngineResult.getStatus()) { + // normal status. + case OK: + return true; + case CLOSED: + sslEngine.closeOutbound(); + return true; + case BUFFER_OVERFLOW: + // Could attempt to drain the serverNetData buffer of any already obtained + // data, but we'll just increase it to the size needed. + ByteBuffer newBuffer = ByteBuffer.allocate(serverNetData.capacity() * 2); + serverNetData.flip(); + newBuffer.put(serverNetData); + serverNetData = newBuffer; + // retry the operation. + return false; + // when wrap BUFFER_UNDERFLOW and other status will not appear. + case BUFFER_UNDERFLOW: + default: + throw new IllegalStateException("invalid wrap status: " + sslEngineResult.getStatus()); + } + } + + private boolean handleUnwrapResult(SSLEngineResult sslEngineResult) { + switch (sslEngineResult.getStatus()) { + // normal status. + case OK: + return true; + case CLOSED: + sslEngine.closeOutbound(); + return true; + case BUFFER_OVERFLOW: + // Could attempt to drain the clientAppData buffer of any already obtained + // data, but we'll just increase it to the size needed. + ByteBuffer newAppBuffer = ByteBuffer.allocate(clientAppData.capacity() * 2); + clientAppData.flip(); + newAppBuffer.put(clientAppData); + clientAppData = newAppBuffer; + // retry the operation. + return false; + case BUFFER_UNDERFLOW: + int netSize = sslEngine.getSession().getPacketBufferSize(); + // Resize buffer if needed. + if (netSize > clientAppData.capacity()) { + ByteBuffer newNetBuffer = ByteBuffer.allocateDirect(netSize); + clientNetData.flip(); + newNetBuffer.put(clientNetData); + clientNetData = newNetBuffer; + } + // Obtain more inbound network data for clientNetData, + // then retry the operation. + return false; + default: + throw new IllegalStateException("invalid wrap status: " + sslEngineResult.getStatus()); + } + + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlSslPacket.java b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlSslPacket.java new file mode 100644 index 0000000000..5fbb843e4b --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/mysql/MysqlSslPacket.java @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.mysql; + +import org.apache.doris.common.Config; + +import java.nio.ByteBuffer; + +public class MysqlSslPacket extends MysqlPacket { + + private int maxPacketSize; + private int characterSet; + private byte[] randomString; + private MysqlCapability capability; + + public boolean readFrom(ByteBuffer buffer) { + // read capability four byte, which CLIENT_PROTOCOL_41 must be set + capability = new MysqlCapability(MysqlProto.readInt4(buffer)); + if (!capability.isProtocol41()) { + return false; + } + // max packet size + maxPacketSize = MysqlProto.readInt4(buffer); + // character set. only support 33(utf-8) + characterSet = MysqlProto.readInt1(buffer); + // reserved 23 bytes + if (new String(MysqlProto.readFixedString(buffer, 3)).equals(Config.proxy_auth_magic_prefix)) { + randomString = new byte[MysqlPassword.SCRAMBLE_LENGTH]; + buffer.get(randomString); + } else { + buffer.position(buffer.position() + 20); + } + return true; + } + + @Override + public void writeTo(MysqlSerializer serializer) { + + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java index 1a5143a5f6..1598ecbdbc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java @@ -32,6 +32,7 @@ import org.apache.doris.mysql.DummyMysqlChannel; import org.apache.doris.mysql.MysqlCapability; import org.apache.doris.mysql.MysqlChannel; import org.apache.doris.mysql.MysqlCommand; +import org.apache.doris.mysql.MysqlSslContext; import org.apache.doris.nereids.StatementContext; import org.apache.doris.plugin.AuditEvent.AuditEventBuilder; import org.apache.doris.resource.Tag; @@ -63,6 +64,8 @@ public class ConnectContext { private static final Logger LOG = LogManager.getLogger(ConnectContext.class); protected static ThreadLocal threadLocalInfo = new ThreadLocal<>(); + private static final String SSL_PROTOCOL = "TLS"; + // set this id before analyze protected volatile long stmtId; protected volatile long forwardedStmtId; @@ -149,6 +152,9 @@ public class ConnectContext { private SessionContext sessionContext; + // This context is used for SSL connection between server and mysql client. + private final MysqlSslContext mysqlSslContext = new MysqlSslContext(SSL_PROTOCOL); + private long userQueryTimeout; /** @@ -171,6 +177,10 @@ public class ConnectContext { return sessionContext; } + public MysqlSslContext getMysqlSslContext() { + return mysqlSslContext; + } + public void setOrUpdateInsertResult(long txnId, String label, String db, String tbl, TransactionStatus txnStatus, long loadedRows, int filteredRows) { if (isTxnModel() && insertResult != null) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java index b45564ec74..767ea4f732 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java @@ -413,7 +413,9 @@ public class ConnectProcessor { executor.execute(); if (i != stmts.size() - 1) { ctx.getState().serverStatus |= MysqlServerStatusFlag.SERVER_MORE_RESULTS_EXISTS; - finalizeCommand(); + if (ctx.getState().getStateType() != MysqlStateType.ERR) { + finalizeCommand(); + } } auditAfterExec(auditStmt, executor.getParsedStmt(), executor.getQueryStatisticsForAuditLog()); // execute failed, skip remaining stmts diff --git a/fe/fe-core/src/test/java/org/apache/doris/mysql/MysqlHandshakePacketTest.java b/fe/fe-core/src/test/java/org/apache/doris/mysql/MysqlHandshakePacketTest.java index 7637a8de66..04516ca420 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/mysql/MysqlHandshakePacketTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/mysql/MysqlHandshakePacketTest.java @@ -76,7 +76,8 @@ public class MysqlHandshakePacketTest { Assert.assertEquals(0, MysqlProto.readInt2(buffer)); // capability flags flags |= MysqlProto.readInt2(buffer) << 16; - Assert.assertEquals(MysqlCapability.DEFAULT_CAPABILITY.getFlags(), flags); + Assert.assertEquals(MysqlProto.SERVER_USE_SSL + ? MysqlCapability.SSL_CAPABILITY.getFlags() : MysqlCapability.DEFAULT_CAPABILITY.getFlags(), flags); // length of plugin data Assert.assertEquals(21, MysqlProto.readInt1(buffer)); // length of plugin data diff --git a/regression-test/certificate.p12 b/regression-test/certificate.p12 new file mode 100644 index 0000000000..d54fde284b Binary files /dev/null and b/regression-test/certificate.p12 differ diff --git a/regression-test/framework/src/main/groovy/org/apache/doris/regression/Config.groovy b/regression-test/framework/src/main/groovy/org/apache/doris/regression/Config.groovy index 6c36000029..027de85c7d 100644 --- a/regression-test/framework/src/main/groovy/org/apache/doris/regression/Config.groovy +++ b/regression-test/framework/src/main/groovy/org/apache/doris/regression/Config.groovy @@ -31,6 +31,7 @@ import java.sql.DriverManager import java.util.concurrent.atomic.AtomicReference import java.util.function.Predicate +import static java.lang.Math.random import static org.apache.doris.regression.ConfigOptions.* @Slf4j @@ -371,7 +372,7 @@ class Config { log.info("Set actionParallel to 10 because not specify.".toString()) } } - + static String configToString(Object obj) { return (obj instanceof String || obj instanceof GString) ? obj.toString() : null } @@ -465,7 +466,8 @@ class Config { if (urlWithoutSchema.indexOf("/") >= 0) { if (jdbcUrl.contains("?")) { // e.g: jdbc:mysql://locahost:8080/?a=b - urlWithDb = jdbcUrl.substring(0, jdbcUrl.lastIndexOf("/")) + urlWithDb = jdbcUrl.substring(0, jdbcUrl.lastIndexOf("?")) + urlWithDb = urlWithDb.substring(0, urlWithDb.lastIndexOf("/")) urlWithDb += ("/" + dbName) + jdbcUrl.substring(jdbcUrl.lastIndexOf("?")) } else { // e.g: jdbc:mysql://locahost:8080/ @@ -475,7 +477,33 @@ class Config { // e.g: jdbc:mysql://locahost:8080 urlWithDb += ("/" + dbName) } + urlWithDb = addSslUrl(urlWithDb); return urlWithDb } + + private String addSslUrl(String url) { + if (url.contains("TLS")) { + return url + } + // ssl-mode = PREFERRED + String useSsl = "true" + String useSslConfig = "verifyServerCertificate=false&useSSL=" + useSsl + "&requireSSL=false" + String tlsVersion = "TLSv1.2" + String tlsVersionConfig = "&enabledTLSProtocols=" + tlsVersion + String keyStoreFile = "file:regression-test/certificate.p12" + String keyStoreFileConfig = "&trustCertificateKeyStoreUrl=" + keyStoreFile + "&clientCertificateKeyStoreUrl=" + keyStoreFile + String password = "&trustCertificateKeyStorePassword=doris&clientCertificateKeyStorePassword=doris" + String sslUrl = useSslConfig + tlsVersionConfig + keyStoreFileConfig + password + // e.g: jdbc:mysql://locahost:8080/dbname? + if (url.charAt(url.length() - 1) == '?') { + return url + sslUrl + // e.g: jdbc:mysql://locahost:8080/dbname?a=b + } else if (url.contains('?')) { + return url + '&' + sslUrl + // e.g: jdbc:mysql://locahost:8080/dbname + } else { + return url + '?' + sslUrl + } + } } diff --git a/regression-test/framework/src/main/groovy/org/apache/doris/regression/suite/Suite.groovy b/regression-test/framework/src/main/groovy/org/apache/doris/regression/suite/Suite.groovy index c4a1eaf23e..2d2acb2b45 100644 --- a/regression-test/framework/src/main/groovy/org/apache/doris/regression/suite/Suite.groovy +++ b/regression-test/framework/src/main/groovy/org/apache/doris/regression/suite/Suite.groovy @@ -519,5 +519,6 @@ class Suite implements GroovyInterceptable { return metaClass.invokeMethod(this, name, args) } } + } diff --git a/regression-test/pipeline/p0/conf/fe.conf b/regression-test/pipeline/p0/conf/fe.conf index 850a4f78ef..e8d9847bf8 100644 --- a/regression-test/pipeline/p0/conf/fe.conf +++ b/regression-test/pipeline/p0/conf/fe.conf @@ -65,6 +65,9 @@ sys_log_verbose_modules = org.apache.doris # qe_slow_log_ms = 5000 # +// enable ssl for test +enable_ssl = true + enable_outfile_to_local = true tablet_create_timeout_second=20 remote_fragment_exec_timeout_ms=60000 diff --git a/regression-test/pipeline/p1/conf/fe.conf b/regression-test/pipeline/p1/conf/fe.conf index 6f737da817..d4f0608e98 100644 --- a/regression-test/pipeline/p1/conf/fe.conf +++ b/regression-test/pipeline/p1/conf/fe.conf @@ -66,6 +66,9 @@ priority_networks=172.19.0.0/24 # qe_slow_log_ms = 5000 # +// enable ssl for test +enable_ssl = true + enable_outfile_to_local = true tablet_create_timeout_second=20 remote_fragment_exec_timeout_ms=60000 diff --git a/regression-test/pipeline/p1/conf/regression-conf.groovy b/regression-test/pipeline/p1/conf/regression-conf.groovy index 455e3d8267..acf8f7788d 100644 --- a/regression-test/pipeline/p1/conf/regression-conf.groovy +++ b/regression-test/pipeline/p1/conf/regression-conf.groovy @@ -45,4 +45,4 @@ cacheDataPath="/data/regression/" s3Endpoint = "cos.ap-hongkong.myqcloud.com" s3BucketName = "doris-build-hk-1308700295" -s3Region = "ap-hongkong" \ No newline at end of file +s3Region = "ap-hongkong" diff --git a/regression-test/suites/mysql_ssl_p0/test_mysql_connection.groovy b/regression-test/suites/mysql_ssl_p0/test_mysql_connection.groovy new file mode 100644 index 0000000000..25e00d1b47 --- /dev/null +++ b/regression-test/suites/mysql_ssl_p0/test_mysql_connection.groovy @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_mysql_connection") { + + def executeMySQLCommand = { String command -> + try { + String line; + StringBuilder errMsg = new StringBuilder(); + StringBuilder msg = new StringBuilder(); + Process p = Runtime.getRuntime().exec(new String[]{"/bin/bash", "-c", command}); + + BufferedReader errInput = new BufferedReader(new InputStreamReader(p.getErrorStream())); + while ((line = errInput.readLine()) != null) { + errMsg.append(line); + } + assert errMsg.length() == 0: "error occurred!" + errMsg.toString(); + errInput.close(); + + BufferedReader input = new BufferedReader(new InputStreamReader(p.getInputStream())); + while ((line = input.readLine()) != null) { + msg.append(line); + } + assert msg.toString().contains("version"): "error occurred!" + errMsg.toString(); + input.close(); + } catch (IOException e) { + e.printStackTrace(); + } + } + + String jdbcUrlConfig = context.config.jdbcUrl; + String tempString = jdbcUrlConfig.substring(jdbcUrlConfig.indexOf("jdbc:mysql://") + 13); + String mysqlHost = tempString.substring(0, tempString.indexOf(":")); + String mysqlPort = tempString.substring(tempString.indexOf(":") + 1, tempString.indexOf("/")); + String cmdDefault = "mysql -uroot -h" + mysqlHost + " -P" + mysqlPort + " -e \"show variables\""; + String cmdDisabledSsl = "mysql --ssl-mode=DISABLE -uroot -h" + mysqlHost + " -P" + mysqlPort + " -e \"show variables\""; + String cmdSsl12 = "mysql --ssl-mode=REQUIRED -uroot -h" + mysqlHost + " -P" + mysqlPort + " --tls-version=TLSv1.2 -e \"show variables\""; + // The current mysql-client version of the test environment is 5.7.32, which does not support TLSv1.3, so comment this part. + // String cmdSsl13 = "mysql --ssl-mode=REQUIRED -uroot -h" + mysqlHost + " -P" + mysqlPort + " --tls-version=TLSv1.3 -e \"show variables\""; + executeMySQLCommand(cmdDefault); + executeMySQLCommand(cmdDisabledSsl); + executeMySQLCommand(cmdSsl12); + // executeMySQLCommand(cmdSsl13); +} diff --git a/regression-test/suites/mysql_ssl_p0/test_ssl_stability.groovy b/regression-test/suites/mysql_ssl_p0/test_ssl_stability.groovy new file mode 100644 index 0000000000..3cd75f4f8d --- /dev/null +++ b/regression-test/suites/mysql_ssl_p0/test_ssl_stability.groovy @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_ssl_stability") { + def tbName = "tb_test_ssl_stability" + int test_count = 5; + while (test_count-- > 1) { + sql "DROP TABLE IF EXISTS ${tbName}" + // char not null to null + sql """ + CREATE TABLE IF NOT EXISTS ${tbName} ( + k1 INT NOT NULL, + value1 varchar(16) NOT NULL + ) + DUPLICATE KEY (k1) + DISTRIBUTED BY HASH(k1) BUCKETS 1 properties("replication_num" = "1"); + """ + StringBuilder insertCommand = new StringBuilder(); + insertCommand.append("INSERT INTO ${tbName} VALUES "); + int insert_row_count = 100000; + while (insert_row_count-- > 1) { + insertCommand.append("(1, '11'),"); + } + insertCommand.append("(1, '11')"); + sql insertCommand.toString() + } +}