[ODBC/MySQL] Support Limit Clause Push Down For ODBC Table And MySQL Table(#4706) (#4707)

1. Support limit clause push down both odbc table and mysql table.
2. Code refactor of ODBC Scan Node, change `build_connect_string` and `query_string` from BE to FE to make it easily to modify
This commit is contained in:
HappenLee
2020-10-11 21:11:04 +08:00
committed by GitHub
parent fd5e3011da
commit d73d205de7
12 changed files with 111 additions and 122 deletions

View File

@ -116,7 +116,8 @@ Status MysqlScanNode::open(RuntimeState* state) {
RETURN_IF_CANCELLED(state);
SCOPED_TIMER(_runtime_profile->total_time_counter());
RETURN_IF_ERROR(_mysql_scanner->open());
RETURN_IF_ERROR(_mysql_scanner->query(_table_name, _columns, _filters));
RETURN_IF_ERROR(_mysql_scanner->query(_table_name, _columns, _filters, _limit));
// check materialize slot num
int materialize_num = 0;
@ -161,11 +162,6 @@ Status MysqlScanNode::get_next(RuntimeState* state, RowBatch* row_batch, bool* e
SCOPED_TIMER(_runtime_profile->total_time_counter());
SCOPED_TIMER(materialize_tuple_timer());
if (reached_limit()) {
*eos = true;
return Status::OK();
}
// create new tuple buffer for row_batch
int tuple_buffer_size = row_batch->capacity() * _tuple_desc->byte_size();
void* tuple_buffer = _tuple_pool->allocate(tuple_buffer_size);
@ -181,11 +177,10 @@ Status MysqlScanNode::get_next(RuntimeState* state, RowBatch* row_batch, bool* e
while (true) {
RETURN_IF_CANCELLED(state);
if (reached_limit() || row_batch->is_full()) {
if (row_batch->is_full()) {
// hang on to last allocated chunk in pool, we'll keep writing into it in the
// next get_next() call
row_batch->tuple_data_pool()->acquire_data(_tuple_pool.get(), !reached_limit());
*eos = reached_limit();
return Status::OK();
}

View File

@ -111,7 +111,7 @@ Status MysqlScanner::query(const std::string& query) {
}
Status MysqlScanner::query(const std::string& table, const std::vector<std::string>& fields,
const std::vector<std::string>& filters) {
const std::vector<std::string>& filters, const uint64_t limit) {
if (!_is_open) {
return Status::InternalError("Query before open.");
}
@ -140,6 +140,10 @@ Status MysqlScanner::query(const std::string& table, const std::vector<std::stri
}
}
if (limit != -1) {
_sql_str += " LIMIT " + std::to_string(limit);
}
return query(_sql_str);
}

View File

@ -56,7 +56,7 @@ public:
// query for DORIS
Status query(const std::string& table, const std::vector<std::string>& fields,
const std::vector<std::string>& filters);
const std::vector<std::string>& filters, const uint64_t limit);
Status get_next_row(char** *buf, unsigned long** lengths, bool* eos);
int field_num() const {

View File

@ -34,9 +34,9 @@ OdbcScanNode::OdbcScanNode(ObjectPool* pool, const TPlanNode& tnode,
: ScanNode(pool, tnode, descs),
_is_init(false),
_table_name(tnode.odbc_scan_node.table_name),
_connect_string(std::move(tnode.odbc_scan_node.connect_string)),
_query_string(std::move(tnode.odbc_scan_node.query_string)),
_tuple_id(tnode.odbc_scan_node.tuple_id),
_columns(tnode.odbc_scan_node.columns),
_filters(tnode.odbc_scan_node.filters),
_tuple_desc(nullptr) {
}
@ -63,21 +63,9 @@ Status OdbcScanNode::prepare(RuntimeState* state) {
}
_slot_num = _tuple_desc->slots().size();
// get odbc table info
const ODBCTableDescriptor* odbc_table =
static_cast<const ODBCTableDescriptor*>(_tuple_desc->table_desc());
if (NULL == odbc_table) {
return Status::InternalError("odbc table pointer is NULL.");
}
_odbc_param.host = odbc_table->host();
_odbc_param.port = odbc_table->port();
_odbc_param.user = odbc_table->user();
_odbc_param.passwd = odbc_table->passwd();
_odbc_param.db = odbc_table->db();
_odbc_param.drivier = odbc_table->driver();
_odbc_param.type = odbc_table->type();
_odbc_param.connect_string = std::move(_connect_string);
_odbc_param.query_string = std::move(_query_string);
_odbc_param.tuple_desc = _tuple_desc;
_odbc_scanner.reset(new (std::nothrow)ODBCScanner(_odbc_param));
@ -119,7 +107,7 @@ Status OdbcScanNode::open(RuntimeState* state) {
RETURN_IF_CANCELLED(state);
SCOPED_TIMER(_runtime_profile->total_time_counter());
RETURN_IF_ERROR(_odbc_scanner->open());
RETURN_IF_ERROR(_odbc_scanner->query(_table_name, _columns, _filters));
RETURN_IF_ERROR(_odbc_scanner->query());
// check materialize slot num
return Status::OK();
@ -153,11 +141,6 @@ Status OdbcScanNode::get_next(RuntimeState* state, RowBatch* row_batch, bool* eo
SCOPED_TIMER(_runtime_profile->total_time_counter());
SCOPED_TIMER(materialize_tuple_timer());
if (reached_limit()) {
*eos = true;
return Status::OK();
}
// create new tuple buffer for row_batch
int tuple_buffer_size = row_batch->capacity() * _tuple_desc->byte_size();
void* tuple_buffer = _tuple_pool->allocate(tuple_buffer_size);
@ -173,11 +156,10 @@ Status OdbcScanNode::get_next(RuntimeState* state, RowBatch* row_batch, bool* eo
while (true) {
RETURN_IF_CANCELLED(state);
if (reached_limit() || row_batch->is_full()) {
if (row_batch->is_full()) {
// hang on to last allocated chunk in pool, we'll keep writing into it in the
// next get_next() call
row_batch->tuple_data_pool()->acquire_data(_tuple_pool.get(), !reached_limit());
*eos = reached_limit();
return Status::OK();
}
@ -238,8 +220,6 @@ Status OdbcScanNode::get_next(RuntimeState* state, RowBatch* row_batch, bool* eo
_tuple = reinterpret_cast<Tuple*>(new_tuple);
}
}
return Status::OK();
}
Status OdbcScanNode::close(RuntimeState* state) {

View File

@ -68,13 +68,12 @@ private:
// Name of Odbc table
std::string _table_name;
std::string _connect_string;
std::string _query_string;
// Tuple id resolved in prepare() to set _tuple_desc;
TupleId _tuple_id;
// select columns
std::vector<std::string> _columns;
// where clause
std::vector<std::string> _filters;
// Descriptor of tuples read from ODBC table.
const TupleDescriptor* _tuple_desc;

View File

@ -47,8 +47,8 @@ static std::u16string utf8_to_wstring(const std::string& str) {
namespace doris {
ODBCScanner::ODBCScanner(const ODBCScannerParam& param)
: _connect_string(build_connect_string(param)),
_type(param.type),
: _connect_string(param.connect_string),
_sql_str(param.query_string),
_tuple_desc(param.tuple_desc),
_is_open(false),
_field_num(0),
@ -97,7 +97,7 @@ Status ODBCScanner::open() {
return Status::OK();
}
Status ODBCScanner::query(const std::string& query) {
Status ODBCScanner::query() {
if (!_is_open) {
return Status::InternalError( "Query before open.");
}
@ -106,13 +106,13 @@ Status ODBCScanner::query(const std::string& query) {
ODBC_DISPOSE(_dbc, SQL_HANDLE_DBC, SQLAllocHandle(SQL_HANDLE_STMT, _dbc, &_stmt), "alloc statement");
// Translate utf8 string to utf16 to use unicode codeing
auto wquery = utf8_to_wstring(query);
auto wquery = utf8_to_wstring(_sql_str);
ODBC_DISPOSE(_stmt, SQL_HANDLE_STMT, SQLExecDirectW(_stmt, (SQLWCHAR*)(wquery.c_str()), SQL_NTS), "exec direct");
// How many columns are there */
ODBC_DISPOSE(_stmt, SQL_HANDLE_STMT, SQLNumResultCols(_stmt, &_field_num), "count num colomn");
LOG(INFO) << "execute success:" << query << " column count:" << _field_num;
LOG(INFO) << "execute success:" << _sql_str << " column count:" << _field_num;
// check materialize num equal _field_num
int materialize_num = 0;
@ -145,39 +145,6 @@ Status ODBCScanner::query(const std::string& query) {
return Status::OK();
}
Status ODBCScanner::query(const std::string& table, const std::vector<std::string>& fields,
const std::vector<std::string>& filters) {
if (!_is_open) {
return Status::InternalError("Query before open.");
}
_sql_str = "SELECT ";
for (int i = 0; i < fields.size(); ++i) {
if (0 != i) {
_sql_str += ",";
}
_sql_str += fields[i];
}
_sql_str += " FROM " + table;
if (!filters.empty()) {
_sql_str += " WHERE ";
for (int i = 0; i < filters.size(); ++i) {
if (0 != i) {
_sql_str += " AND";
}
_sql_str += " (" + filters[i] + ") ";
}
}
return query(_sql_str);
}
Status ODBCScanner::get_next_row(bool* eos) {
if (!_is_open) {
return Status::InternalError("GetNextRow before open.");
@ -240,23 +207,4 @@ std::string ODBCScanner::handle_diagnostic_record(SQLHANDLE hHandle,
return diagnostic_msg;
}
std::string ODBCScanner::build_connect_string(const ODBCScannerParam& param) {
// different database have different connection string
// oracle connect string
if (param.type == TOdbcTableType::ORACLE) {
boost::format connect_string("Driver=%s;Dbq=//%s:%s/%s;DataBase=%s;Uid=%s;Pwd=%s;charset=%s");
connect_string % param.drivier % param.host % param.port % param.db % param.db % param.user % param.passwd %
param.charest;
return connect_string.str();
} else if (param.type == TOdbcTableType::MYSQL) {
boost::format connect_string("Driver=%s;Server=%s;Port=%s;DataBase=%s;Uid=%s;Pwd=%s;charset=%s");
connect_string % param.drivier % param.host % param.port % param.db % param.user % param.passwd %
param.charest;
return connect_string.str();
}
return "";
}
}

View File

@ -32,15 +32,9 @@
namespace doris {
struct ODBCScannerParam {
std::string host;
std::string port;
std::string user;
std::string passwd;
std::string db;
std::string drivier;
std::string charest = "utf8";
std::string connect_string;
std::string query_string;
TOdbcTableType::type type;
const TupleDescriptor* tuple_desc;
};
@ -67,11 +61,8 @@ public:
Status open();
Status query(const std::string& query);
// query for DORIS
Status query(const std::string& table, const std::vector<std::string>& fields,
const std::vector<std::string>& filters);
// query for ODBC table
Status query();
Status get_next_row(bool* eos);
@ -80,8 +71,6 @@ public:
}
private:
static std::string build_connect_string(const ODBCScannerParam& param);
static Status error_status(const std::string& prefix, const std::string& error_msg);
static std::string handle_diagnostic_record (SQLHANDLE hHandle,
@ -90,7 +79,6 @@ private:
std::string _connect_string;
std::string _sql_str;
TOdbcTableType::type _type;
const TupleDescriptor* _tuple_desc;
bool _is_open;

View File

@ -59,8 +59,6 @@ public class OdbcTable extends Table {
static {
Map<String, TOdbcTableType> tempMap = new HashMap<>();
tempMap.put("oracle", TOdbcTableType.ORACLE);
// we will support mysql driver in the future after we solve the core problem of
// driver and static library
tempMap.put("mysql", TOdbcTableType.MYSQL);
TABLE_TYPE_MAP = Collections.unmodifiableMap(tempMap);
}
@ -244,6 +242,36 @@ public class OdbcTable extends Table {
return getPropertyFromResource(ODBC_TYPE);
}
public String getConnectString() {
String connectString = "";
// different database have different connection string
switch (getOdbcTableType()) {
case ORACLE:
connectString = String.format("Driver=%s;Dbq=//%s:%s/%s;DataBase=%s;Uid=%s;Pwd=%s;charset=%s",
getOdbcDriver(),
getHost(),
getPort(),
getOdbcDatabaseName(),
getOdbcDatabaseName(),
getUserName(),
getPasswd(),
"utf8");
break;
case MYSQL:
connectString = String.format("Driver=%s;Server=%s;Port=%s;DataBase=%s;Uid=%s;Pwd=%s;charset=%s",
getOdbcDriver(),
getHost(),
getPort(),
getOdbcDatabaseName(),
getUserName(),
getPasswd(),
"utf8");
break;
default:
}
return connectString;
}
public TOdbcTableType getOdbcTableType() {
return TABLE_TYPE_MAP.get(getOdbcTableTypeName());
}

View File

@ -92,6 +92,11 @@ public class MysqlScanNode extends ScanNode {
sql.append(Joiner.on(") AND (").join(filters));
sql.append(")");
}
if (limit != -1) {
sql.append(" LIMIT " + limit);
}
return sql.toString();
}

View File

@ -78,7 +78,7 @@ public class OdbcScanNode extends ScanNode {
private final List<String> columns = new ArrayList<String>();
private final List<String> filters = new ArrayList<String>();
private String tblName;
private String driver;
private String connectString;
private TOdbcTableType odbcType;
/**
@ -86,7 +86,7 @@ public class OdbcScanNode extends ScanNode {
*/
public OdbcScanNode(PlanNodeId id, TupleDescriptor desc, OdbcTable tbl) {
super(id, desc, "SCAN ODBC");
driver = tbl.getOdbcDriver();
connectString = tbl.getConnectString();
odbcType = tbl.getOdbcTableType();
tblName = databaseProperName(odbcType, tbl.getOdbcTableName());
}
@ -109,12 +109,24 @@ public class OdbcScanNode extends ScanNode {
protected String getNodeExplainString(String prefix, TExplainLevel detailLevel) {
StringBuilder output = new StringBuilder();
output.append(prefix).append("TABLE: ").append(tblName).append("\n");
output.append(prefix).append("Query: ").append(getOdbcQueryStr()).append("\n");
output.append(prefix).append("TABLE TYPE: ").append(odbcType.toString()).append("\n");
output.append(prefix).append("QUERY: ").append(getOdbcQueryStr()).append("\n");
return output.toString();
}
private String getOdbcQueryStr() {
StringBuilder sql = new StringBuilder("SELECT ");
// Oracle use the where clause to do top n
if (limit != -1 && odbcType == TOdbcTableType.ORACLE) {
filters.add("ROWNUM <= " + limit);
}
// MSSQL use select top to do top n
if (limit != -1 && odbcType == TOdbcTableType.SQLSERVER) {
sql.append("TOP " + limit + " ");
}
sql.append(Joiner.on(", ").join(columns));
sql.append(" FROM ").append(tblName);
@ -123,6 +135,12 @@ public class OdbcScanNode extends ScanNode {
sql.append(Joiner.on(") AND (").join(filters));
sql.append(")");
}
// Other DataBase use limit do top n
if (limit != -1 && (odbcType == TOdbcTableType.MYSQL || odbcType == TOdbcTableType.POSTGRESQL || odbcType == TOdbcTableType.MONGODB) ) {
sql.append(" LIMIT " + limit);
}
return sql.toString();
}
@ -172,10 +190,8 @@ public class OdbcScanNode extends ScanNode {
TOdbcScanNode odbcScanNode = new TOdbcScanNode();
odbcScanNode.setTupleId(desc.getId().asInt());
odbcScanNode.setTableName(tblName);
odbcScanNode.setDriver(driver);
odbcScanNode.setType(odbcType);
odbcScanNode.setColumns(columns);
odbcScanNode.setFilters(filters);
odbcScanNode.setConnectString(connectString);
odbcScanNode.setQueryString(getOdbcQueryStr());
msg.odbc_scan_node = odbcScanNode;
}

View File

@ -1214,6 +1214,26 @@ public class QueryPlanTest {
Assert.assertTrue(!explainString.contains("abs(k1) > 10"));
}
@Test
public void testLimitOfExternalTable() throws Exception {
connectContext.setDatabase("default_cluster:test");
// ODBC table (MySQL)
String queryStr = "explain select * from odbc_mysql where k1 > 10 and abs(k1) > 10 limit 10";
String explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, queryStr);
Assert.assertTrue(explainString.contains("LIMIT 10"));
// ODBC table (Oracle)
queryStr = "explain select * from odbc_oracle where k1 > 10 and abs(k1) > 10 limit 10";
explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, queryStr);
Assert.assertTrue(explainString.contains("ROWNUM <= 10"));
// MySQL table
queryStr = "explain select * from mysql_table where k1 > 10 and abs(k1) > 10 limit 10";
explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, queryStr);
Assert.assertTrue(explainString.contains("LIMIT 10"));
}
@Test
public void testPreferBroadcastJoin() throws Exception {
connectContext.setDatabase("default_cluster:test");

View File

@ -202,10 +202,16 @@ struct TMySQLScanNode {
struct TOdbcScanNode {
1: optional Types.TTupleId tuple_id
2: optional string table_name
//Deprecated
3: optional string driver
4: optional Types.TOdbcTableType type
5: optional list<string> columns
6: optional list<string> filters
//Use now
7: optional string connect_string
8: optional string query_string
}