[feature](sql-dialect) support convert sql use sql convertor service (#27581)

Add a new FE config `sql_convertor_service`.
If this config is set and the session variable `sql_dialect` is also set,
Doris will call a standalone SQL convertor service to convert the user's input SQL
from the specified dialect into Doris SQL before parsing it. For example:

```
mysql> set sql_dialect="presto";
Query OK, 0 rows affected (0.02 sec)

Database changed
mysql> select * from db1.tbl1 where "k1" = 1;  # will be converted to select * from db1.tbl1 where `k1` = 1;
+------+------+
| k1   | k2   |
+------+------+
|    1 |    2 |
+------+------+
1 row in set (0.08 sec)
```

The SQL convertor service must be an HTTP service.
The request and response body formats are documented in `SQLDialectUtils.java`.
This commit is contained in:
Mingyu Chen
2023-12-18 10:32:52 +08:00
committed by GitHub
parent d11365da9c
commit 6e855dd198
8 changed files with 406 additions and 12 deletions

View File

@ -2331,4 +2331,8 @@ public class Config extends ConfigBase {
// Switch for fetching log files through the HTTP interface; off by default.
@ConfField(description = {"是否开启通过http接口获取log文件的功能",
"Whether to enable the function of getting log files through http interface"})
public static boolean enable_get_log_file_api = false;
// Address of the external service used for SQL dialect conversion.
// Empty by default: SQLDialectUtils returns the statement unchanged while this is empty,
// and otherwise opens the value as a URL and POSTs the conversion request to it.
@ConfField(description = {"用于SQL方言转换的服务地址。",
"The service address for SQL dialect conversion."})
public static String sql_convertor_service = "";
}

View File

@ -0,0 +1,183 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.common.util;
import org.apache.doris.common.Config;
import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.nereids.parser.ParseDialect;
import org.apache.doris.qe.ConnectContext;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import lombok.Data;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.lang.reflect.Type;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
/**
 * Converts SQL written in another dialect into Doris SQL by calling the
 * external sql convertor service configured in {@code Config.sql_convertor_service}.
 * The sql convertor service is an HTTP service that accepts a JSON POST request.
 * <p>
 * Request body:
 * <pre>
 * {
 *     "version": "v1",
 *     "sql_query": "select * from t",
 *     "from": "presto",
 *     "to": "doris",
 *     "source": "text",
 *     "case_sensitive": "0"
 * }
 * </pre>
 * Response body:
 * <pre>
 * {
 *     "version": "v1",
 *     "data": "select * from t",
 *     "code": 0,
 *     "message": ""
 * }
 * </pre>
 * Conversion is best effort: on any failure the original statement is returned unchanged.
 */
public class SQLDialectUtils {
    private static final Logger LOG = LogManager.getLogger(SQLDialectUtils.class);

    // Gson instances are thread-safe, so one shared instance is enough.
    private static final Gson GSON = new Gson();

    // Upper bounds on how long a query may wait for the convertor service.
    // HttpURLConnection defaults to no timeout at all, which would let a hung
    // service block the calling thread forever.
    private static final int CONNECT_TIMEOUT_MS = 5000;
    private static final int READ_TIMEOUT_MS = 30000;

    /**
     * Converts {@code originStmt} according to the session's sql dialect.
     *
     * @param originStmt the statement as sent by the client
     * @param ctx the connect context whose session variable supplies the dialect
     * @param mysqlCommand the mysql command being processed; only COM_QUERY is converted
     * @return the converted statement, or {@code originStmt} itself when the command is not
     *         COM_QUERY, the service is not configured, the dialect is unset or unsupported,
     *         or the conversion fails
     */
    public static String convertStmtWithDialect(String originStmt, ConnectContext ctx, MysqlCommand mysqlCommand) {
        if (mysqlCommand != MysqlCommand.COM_QUERY) {
            return originStmt;
        }
        // Read the config once so the emptiness check and the request use the same value.
        String serviceUrl = Config.sql_convertor_service;
        if (serviceUrl == null || serviceUrl.isEmpty()) {
            return originStmt;
        }
        ParseDialect.Dialect dialect = ctx.getSessionVariable().getSqlParseDialect();
        if (dialect == null) {
            return originStmt;
        }
        switch (dialect) {
            case PRESTO:
                return convertStmtWithPresto(originStmt, serviceUrl);
            default:
                LOG.debug("only support presto dialect now.");
                return originStmt;
        }
    }

    /**
     * POSTs a presto-to-doris conversion request to {@code targetURL} and returns the converted sql.
     * Returns {@code originStmt} when the service cannot be reached, times out, answers with a
     * non-200 status, or returns a response that is empty, unparsable, not version "v1",
     * has a non-zero code, or carries no converted sql.
     */
    private static String convertStmtWithPresto(String originStmt, String targetURL) {
        ConvertRequest convertRequest = new ConvertRequest(originStmt, "presto");
        HttpURLConnection connection = null;
        try {
            URL url = new URL(targetURL);
            connection = (HttpURLConnection) url.openConnection();
            connection.setRequestMethod("POST");
            connection.setRequestProperty("Content-Type", "application/json");
            connection.setUseCaches(false);
            connection.setDoOutput(true);
            connection.setConnectTimeout(CONNECT_TIMEOUT_MS);
            connection.setReadTimeout(READ_TIMEOUT_MS);

            String requestStr = convertRequest.toJson();
            try (OutputStream outputStream = connection.getOutputStream()) {
                outputStream.write(requestStr.getBytes(StandardCharsets.UTF_8));
            }

            int responseCode = connection.getResponseCode();
            LOG.debug("POST Response Code: {}, post data: {}", responseCode, requestStr);
            if (responseCode != HttpURLConnection.HTTP_OK) {
                LOG.warn("failed to convert sql, response code: {}", responseCode);
                return originStmt;
            }

            ConvertResponse result = readResponse(connection);
            LOG.debug("convert response: {}", result);
            if (result == null) {
                // Gson yields null for an empty body.
                LOG.warn("failed to convert sql, response body is empty");
                return originStmt;
            }
            if (result.code != 0) {
                LOG.warn("failed to convert sql, response: {}", result);
                return originStmt;
            }
            if (!"v1".equals(result.version)) {
                LOG.warn("failed to convert sql, response version is not v1: {}", result.version);
                return originStmt;
            }
            if (result.data == null || result.data.isEmpty()) {
                // Never hand a null/empty statement back to the caller, which hashes and parses it.
                LOG.warn("failed to convert sql, response has no converted sql: {}", result);
                return originStmt;
            }
            return result.data;
        } catch (Exception e) {
            LOG.warn("failed to convert sql", e);
            return originStmt;
        } finally {
            if (connection != null) {
                connection.disconnect();
            }
        }
    }

    /**
     * Reads the whole UTF-8 response body of {@code connection} and parses it as a
     * {@link ConvertResponse}. Returns null when the body is empty.
     */
    private static ConvertResponse readResponse(HttpURLConnection connection) throws Exception {
        try (InputStreamReader inputStreamReader
                = new InputStreamReader(connection.getInputStream(), StandardCharsets.UTF_8);
                BufferedReader in = new BufferedReader(inputStreamReader)) {
            StringBuilder response = new StringBuilder();
            String inputLine;
            while ((inputLine = in.readLine()) != null) {
                response.append(inputLine);
            }
            Type type = new TypeToken<ConvertResponse>() {
            }.getType();
            return GSON.fromJson(response.toString(), type);
        }
    }

    /**
     * Request body sent to the convertor service. Field names are the JSON keys,
     * which is why they do not follow the Java naming convention.
     */
    @Data
    private static class ConvertRequest {
        private String version;    // CHECKSTYLE IGNORE THIS LINE
        private String sql_query;    // CHECKSTYLE IGNORE THIS LINE
        private String from;    // CHECKSTYLE IGNORE THIS LINE
        private String to;    // CHECKSTYLE IGNORE THIS LINE
        private String source;    // CHECKSTYLE IGNORE THIS LINE
        private String case_sensitive;    // CHECKSTYLE IGNORE THIS LINE

        /**
         * Builds a "v1" text request converting {@code originStmt} from {@code dialect} to doris.
         */
        public ConvertRequest(String originStmt, String dialect) {
            this.version = "v1";
            this.sql_query = originStmt;
            this.from = dialect;
            this.to = "doris";
            this.source = "text";
            this.case_sensitive = "0";
        }

        /** Serializes this request to the JSON body expected by the service. */
        public String toJson() {
            return GSON.toJson(this);
        }
    }

    /**
     * Response body returned by the convertor service. Field names are the JSON keys.
     */
    @Data
    private static class ConvertResponse {
        private String version;    // CHECKSTYLE IGNORE THIS LINE
        private String data;    // CHECKSTYLE IGNORE THIS LINE
        private int code;    // CHECKSTYLE IGNORE THIS LINE
        private String message;    // CHECKSTYLE IGNORE THIS LINE

        /** Serializes this response back to JSON; used for logging. */
        public String toJson() {
            return GSON.toJson(this);
        }

        @Override
        public String toString() {
            return toJson();
        }
    }
}

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.parser;
import org.apache.doris.analysis.StatementBase;
import org.apache.doris.common.Config;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.DorisLexer;
import org.apache.doris.nereids.DorisParser;
@ -30,6 +31,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.qe.SessionVariable;
import com.google.common.base.Strings;
import com.google.common.collect.Lists;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
@ -81,6 +83,10 @@ public class NereidsParser {
private List<StatementBase> parseSQLWithDialect(String sql,
@Nullable ParseDialect.Dialect sqlDialect,
SessionVariable sessionVariable) {
if (!Strings.isNullOrEmpty(Config.sql_convertor_service)) {
// if sql convertor service is enabled, no need to parse sql again by specific dialect.
return parseSQL(sql);
}
switch (sqlDialect) {
case TRINO:
final List<StatementBase> logicalPlans = TrinoParser.parse(sql, sessionVariable);

View File

@ -88,6 +88,10 @@ public enum ParseDialect {
* Trino parser dialect
*/
TRINO("trino"),
/**
* Presto parser dialect
*/
PRESTO("presto"),
/**
* Doris parser dialect
*/

View File

@ -35,6 +35,7 @@ import org.apache.doris.common.ErrorCode;
import org.apache.doris.common.NotImplementedException;
import org.apache.doris.common.UserException;
import org.apache.doris.common.util.DebugUtil;
import org.apache.doris.common.util.SQLDialectUtils;
import org.apache.doris.common.util.SqlParserUtils;
import org.apache.doris.common.util.SqlUtils;
import org.apache.doris.common.util.Util;
@ -169,7 +170,9 @@ public abstract class ConnectProcessor {
MetricRepo.COUNTER_REQUEST_ALL.increase(1L);
}
String sqlHash = DigestUtils.md5Hex(originStmt);
String convertedStmt = SQLDialectUtils.convertStmtWithDialect(originStmt, ctx, mysqlCommand);
String sqlHash = DigestUtils.md5Hex(convertedStmt);
ctx.setSqlHash(sqlHash);
ctx.getAuditEventBuilder().reset();
ctx.getAuditEventBuilder()
@ -183,25 +186,25 @@ public abstract class ConnectProcessor {
// Nereids do not support prepare and execute now, so forbid prepare command, only process query command
if (mysqlCommand == MysqlCommand.COM_QUERY && ctx.getSessionVariable().isEnableNereidsPlanner()) {
try {
stmts = new NereidsParser().parseSQL(originStmt, ctx.getSessionVariable());
stmts = new NereidsParser().parseSQL(convertedStmt, ctx.getSessionVariable());
} catch (NotSupportedException e) {
// Parse sql failed, audit it and return
handleQueryException(e, originStmt, null, null);
handleQueryException(e, convertedStmt, null, null);
return;
} catch (Exception e) {
// TODO: We should catch all exception here until we support all query syntax.
LOG.debug("Nereids parse sql failed. Reason: {}. Statement: \"{}\".",
e.getMessage(), originStmt);
e.getMessage(), convertedStmt);
}
}
// stmts == null when Nereids cannot planner this query or Nereids is disabled.
if (stmts == null) {
try {
stmts = parse(originStmt);
stmts = parse(convertedStmt);
} catch (Throwable throwable) {
// Parse sql failed, audit it and return
handleQueryException(throwable, originStmt, null, null);
handleQueryException(throwable, convertedStmt, null, null);
return;
}
}
@ -210,15 +213,15 @@ public abstract class ConnectProcessor {
// if stmts.size() > 1, split originStmt to multi singleStmts
if (stmts.size() > 1) {
try {
origSingleStmtList = SqlUtils.splitMultiStmts(originStmt);
origSingleStmtList = SqlUtils.splitMultiStmts(convertedStmt);
} catch (Exception ignore) {
LOG.warn("Try to parse multi origSingleStmt failed, originStmt: \"{}\"", originStmt);
LOG.warn("Try to parse multi origSingleStmt failed, originStmt: \"{}\"", convertedStmt);
}
}
boolean usingOrigSingleStmt = origSingleStmtList != null && origSingleStmtList.size() == stmts.size();
for (int i = 0; i < stmts.size(); ++i) {
String auditStmt = usingOrigSingleStmt ? origSingleStmtList.get(i) : originStmt;
String auditStmt = usingOrigSingleStmt ? origSingleStmtList.get(i) : convertedStmt;
ctx.getState().reset();
if (i > 0) {
@ -226,7 +229,7 @@ public abstract class ConnectProcessor {
}
StatementBase parsedStmt = stmts.get(i);
parsedStmt.setOrigStmt(new OriginStatement(originStmt, i));
parsedStmt.setOrigStmt(new OriginStatement(convertedStmt, i));
parsedStmt.setUserInfo(ctx.getCurrentUserIdentity());
executor = new StmtExecutor(ctx, parsedStmt);
ctx.setExecutor(executor);

View File

@ -0,0 +1,86 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.common.util;
import org.apache.doris.common.Config;
import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.utframe.SimpleHttpServer;
import org.apache.doris.utframe.TestWithFeService;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
/**
 * Tests {@link SQLDialectUtils#convertStmtWithDialect} against a local {@link SimpleHttpServer}
 * that plays the role of the sql convertor service.
 */
public class SQLDialectUtilsTest {

    int port;
    SimpleHttpServer server;
    // Value of the global config before the test ran; restored in tearDown()
    // so this test does not change the behavior of tests that run after it.
    private String savedConvertorService;

    @Before
    public void setUp() throws Exception {
        savedConvertorService = Config.sql_convertor_service;
        // Step 2 of the test relies on the service being unconfigured at the start.
        Config.sql_convertor_service = "";
        port = TestWithFeService.findValidPort();
        server = new SimpleHttpServer(port);
        server.start("/api/v1/convert");
    }

    @After
    public void tearDown() {
        Config.sql_convertor_service = savedConvertorService;
        if (server != null) {
            server.stop();
        }
    }

    @Test
    public void testSqlConvert() throws IOException {
        String originSql = "select * from t1 where \"k1\" = 1";
        String expectedSql = "select * from t1 where `k1` = 1";
        ConnectContext ctx = TestWithFeService.createDefaultCtx();
        // 1. not COM_QUERY
        String res = SQLDialectUtils.convertStmtWithDialect(originSql, ctx, MysqlCommand.COM_STMT_RESET);
        Assert.assertEquals(originSql, res);
        // 2. config sql_convertor_service not set
        res = SQLDialectUtils.convertStmtWithDialect(originSql, ctx, MysqlCommand.COM_QUERY);
        Assert.assertEquals(originSql, res);
        // 3. session var sql_dialect not set
        Config.sql_convertor_service = "http://127.0.0.1:" + port + "/api/v1/convert";
        res = SQLDialectUtils.convertStmtWithDialect(originSql, ctx, MysqlCommand.COM_QUERY);
        Assert.assertEquals(originSql, res);
        // 4. not support dialect
        ctx.getSessionVariable().setSqlDialect("sqlserver");
        res = SQLDialectUtils.convertStmtWithDialect(originSql, ctx, MysqlCommand.COM_QUERY);
        Assert.assertEquals(originSql, res);
        // 5. test presto
        ctx.getSessionVariable().setSqlDialect("presto");
        server.setResponse("{\"version\": \"v1\", \"data\": \"" + expectedSql + "\", \"code\": 0, \"message\": \"\"}");
        res = SQLDialectUtils.convertStmtWithDialect(originSql, ctx, MysqlCommand.COM_QUERY);
        Assert.assertEquals(expectedSql, res);
        // 6. test response version error
        server.setResponse("{\"version\": \"v2\", \"data\": \"" + expectedSql + "\", \"code\": 0, \"message\": \"\"}");
        res = SQLDialectUtils.convertStmtWithDialect(originSql, ctx, MysqlCommand.COM_QUERY);
        Assert.assertEquals(originSql, res);
        // 7. test response code error
        server.setResponse(
                "{\"version\": \"v1\", \"data\": \"" + expectedSql + "\", \"code\": 400, \"message\": \"\"}");
        res = SQLDialectUtils.convertStmtWithDialect(originSql, ctx, MysqlCommand.COM_QUERY);
        Assert.assertEquals(originSql, res);
    }
}

View File

@ -0,0 +1,108 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.utframe;
import com.sun.net.httpserver.HttpExchange;
import com.sun.net.httpserver.HttpHandler;
import com.sun.net.httpserver.HttpServer;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
/**
 * A simple HTTP server for testing.
 * It uses the JDK's built-in {@link HttpServer} with its default executor,
 * so it is not meant to handle concurrent requests.
 * It answers every POST request on the registered path with a fixed body,
 * which can be set by {@link #setResponse(String)}; any other method gets a 405.
 * Request bodies are decoded and response bodies are encoded as UTF-8.
 */
public class SimpleHttpServer {
    private final int port;
    private HttpServer server;
    // Written by the test thread and read by the server's handler thread,
    // hence volatile so the handler always sees the latest value.
    private volatile String response;

    /**
     * @param port the local port the server will listen on once {@link #start(String)} is called
     */
    public SimpleHttpServer(int port) {
        this.port = port;
    }

    /** Sets the body returned for subsequent POST requests. */
    public void setResponse(String response) {
        this.response = response;
    }

    /** Returns the currently configured response body, or null if none has been set. */
    public String getResponse() {
        return response;
    }

    /**
     * Binds the server to the configured port and starts serving {@code path}.
     *
     * @param path the context path to handle, e.g. "/api/v1/convert"
     * @throws IOException if the server cannot be created or bound
     */
    public void start(String path) throws IOException {
        server = HttpServer.create(new InetSocketAddress(port), 0);
        server.createContext(path, new SqlHandler(this));
        // null selects the JDK's default executor.
        server.setExecutor(null);
        server.start();
    }

    /** Stops the server immediately; does nothing if it was never started. */
    public void stop() {
        if (server != null) {
            server.stop(0);
        }
    }

    /** Handler that echoes the owning server's configured response for POST requests. */
    private static class SqlHandler implements HttpHandler {

        private final SimpleHttpServer server;

        public SqlHandler(SimpleHttpServer server) {
            this.server = server;
        }

        @Override
        public void handle(HttpExchange exchange) throws IOException {
            try {
                if ("POST".equals(exchange.getRequestMethod())) {
                    InputStream requestBody = exchange.getRequestBody();
                    String body = new String(readAllBytes(requestBody), StandardCharsets.UTF_8);
                    System.out.println(body);
                    // A request may arrive before setResponse() was called; answer with an
                    // empty body instead of failing inside the handler.
                    String configured = server.getResponse();
                    sendText(exchange, 200, configured == null ? "" : configured);
                } else {
                    sendText(exchange, 405, "Unsupported method");
                }
            } finally {
                // Releases the exchange's streams even if reading or writing failed.
                exchange.close();
            }
        }

        /** Sends {@code text} as the UTF-8 encoded response body with the given status code. */
        private static void sendText(HttpExchange exchange, int status, String text) throws IOException {
            // Encode once so the advertised length always matches the bytes written.
            byte[] payload = text.getBytes(StandardCharsets.UTF_8);
            exchange.sendResponseHeaders(status, payload.length);
            try (OutputStream responseBody = exchange.getResponseBody()) {
                responseBody.write(payload);
            }
        }
    }

    /** Reads {@code inputStream} until end of stream and returns everything that was read. */
    private static byte[] readAllBytes(InputStream inputStream) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        byte[] data = new byte[1024];
        int nRead;
        while ((nRead = inputStream.read(data, 0, data.length)) != -1) {
            buffer.write(data, 0, nRead);
        }
        return buffer.toByteArray();
    }
}

View File

@ -185,7 +185,7 @@ public abstract class TestWithFeService {
}
// Help to create a mocked ConnectContext.
protected ConnectContext createDefaultCtx() throws IOException {
public static ConnectContext createDefaultCtx() throws IOException {
return createCtx(UserIdentity.ROOT, "127.0.0.1");
}
@ -262,7 +262,7 @@ public abstract class TestWithFeService {
return adapter;
}
protected ConnectContext createCtx(UserIdentity user, String host) throws IOException {
protected static ConnectContext createCtx(UserIdentity user, String host) throws IOException {
ConnectContext ctx = new ConnectContext();
ctx.setCurrentUserIdentity(user);
ctx.setQualifiedUser(user.getQualifiedUser());