[improvement](jdbc) support insert autoinc and default value column to mysql (#20765)

In JdbcMysqlClient, I've added methods to retrieve auto-increment and default value columns from MySQL. These columns are then mapped into Doris metadata to make them visible to users.

When handling the InsertStmt into an execution plan, Doris used to automatically fill in NULL or default values for columns not specified in the InsertStmt. However, in the JDBC catalog, we don't need Doris to handle these unspecified columns, so I've made changes to skip them directly.

For the insert prepared statement required for writing, our previous behavior was to obtain all columns for placeholders. So, the change I made is to pass in the columns processed by the execution plan during the sink task generation stage for dynamic generation.
This commit is contained in:
zy-kkk
2023-06-16 23:38:11 +08:00
committed by GitHub
parent e834637a5b
commit 367f64e7bd
9 changed files with 101 additions and 11 deletions

View File

@ -283,3 +283,10 @@ create table doris_test.all_types (
`varbinary` varbinary(12),
`enum` enum('Value1', 'Value2', 'Value3')
) engine=innodb charset=utf8;
CREATE TABLE `doris_test`.`auto_default_t` (
`id` bigint NOT NULL AUTO_INCREMENT,
`name` varchar(64) DEFAULT NULL,
`dt` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (`id`)
) engine=innodb charset=utf8;

View File

@ -47,6 +47,7 @@ import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.planner.DataPartition;
import org.apache.doris.planner.DataSink;
import org.apache.doris.planner.ExportSink;
import org.apache.doris.planner.JdbcTableSink;
import org.apache.doris.planner.OlapTableSink;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.rewrite.ExprRewriter;
@ -392,6 +393,9 @@ public class NativeInsertStmt extends InsertStmt {
// check columns of target table
for (Column col : baseColumns) {
if (col.isAutoInc()) {
continue;
}
if (isPartialUpdate && !partialUpdateCols.contains(col.getName())) {
continue;
}
@ -720,6 +724,9 @@ public class NativeInsertStmt extends InsertStmt {
}
if (exprByName.containsKey(col.getName())) {
resultExprByName.add(Pair.of(col.getName(), exprByName.get(col.getName())));
} else if (targetTable.getType().equals(TableIf.TableType.JDBC_EXTERNAL_TABLE)) {
// For JdbcTable,we do not need to generate plans for columns that are not specified at write time
continue;
} else {
// process sequence col, map sequence column to other column
if (targetTable instanceof OlapTable && ((OlapTable) targetTable).hasSequenceCol()
@ -771,6 +778,15 @@ public class NativeInsertStmt extends InsertStmt {
table.getLineDelimiter(),
brokerDesc);
dataPartition = dataSink.getOutputPartition();
} else if (targetTable instanceof JdbcTable) {
//for JdbcTable,we need to pass the currently written column to `JdbcTableSink`
//to generate the prepare insert statment
List<String> insertCols = Lists.newArrayList();
for (Column column : targetColumns) {
insertCols.add(column.getName());
}
dataSink = new JdbcTableSink((JdbcTable) targetTable, insertCols);
dataPartition = DataPartition.UNPARTITIONED;
} else {
dataSink = DataSink.createDataSink(targetTable);
dataPartition = DataPartition.UNPARTITIONED;

View File

@ -97,11 +97,15 @@ public class JdbcTable extends Table {
super(id, name, type, schema);
}
public String getInsertSql() {
public String getInsertSql(List<String> insertCols) {
StringBuilder sb = new StringBuilder("INSERT INTO ");
sb.append(databaseProperName(TABLE_TYPE_MAP.get(getTableTypeName()), getExternalTableName()));
sb.append("(");
sb.append(String.join(",", insertCols));
sb.append(")");
sb.append(" VALUES (");
for (int i = 0; i < getFullSchema().size(); ++i) {
for (int i = 0; i < insertCols.size(); ++i) {
if (i != 0) {
sb.append(", ");
}

View File

@ -292,6 +292,7 @@ public abstract class JdbcClient {
String catalogName = getCatalogName(conn);
tableName = modifyTableNameIfNecessary(tableName);
rs = getColumns(databaseMetaData, catalogName, dbName, tableName);
List<String> primaryKeys = getPrimaryKeys(dbName, tableName);
while (rs.next()) {
if (isTableModified(tableName, rs.getString("TABLE_NAME"))) {
continue;
@ -300,6 +301,7 @@ public abstract class JdbcClient {
field.setColumnName(rs.getString("COLUMN_NAME"));
field.setDataType(rs.getInt("DATA_TYPE"));
field.setDataTypeName(rs.getString("TYPE_NAME"));
field.setKey(primaryKeys.contains(field.getColumnName()));
field.setColumnSize(rs.getInt("COLUMN_SIZE"));
field.setDecimalDigits(rs.getInt("DECIMAL_DIGITS"));
field.setNumPrecRadix(rs.getInt("NUM_PREC_RADIX"));
@ -328,7 +330,7 @@ public abstract class JdbcClient {
List<Column> dorisTableSchema = Lists.newArrayListWithCapacity(jdbcTableSchema.size());
for (JdbcFieldSchema field : jdbcTableSchema) {
dorisTableSchema.add(new Column(field.getColumnName(),
jdbcTypeToDoris(field), true, null,
jdbcTypeToDoris(field), field.isKey, null,
field.isAllowNull(), field.getRemarks(),
true, -1));
}
@ -387,6 +389,19 @@ public abstract class JdbcClient {
return databaseMetaData.getColumns(catalogName, schemaName, tableName, null);
}
/**
* We used this method to retrieve the key column of the JDBC table, but since we only tested mysql,
* we kept the default key behavior in the parent class and only overwrite it in the mysql subclass
*/
protected List<String> getPrimaryKeys(String dbName, String tableName) {
List<String> primaryKeys = Lists.newArrayList();
List<JdbcFieldSchema> columns = getJdbcColumnsInfo(dbName, tableName);
for (JdbcFieldSchema column : columns) {
primaryKeys.add(column.getColumnName());
}
return primaryKeys;
}
@Data
protected static class JdbcFieldSchema {
protected String columnName;
@ -394,6 +409,7 @@ public abstract class JdbcClient {
protected int dataType;
// The SQL type of the corresponding java.sql.types (Type Name)
protected String dataTypeName;
protected boolean isKey;
// For CHAR/DATA, columnSize means the maximum number of chars.
// For NUMERIC/DECIMAL, columnSize means precision.
protected int columnSize;
@ -407,6 +423,8 @@ public abstract class JdbcClient {
// because for utf8 encoding, a Chinese character takes up 3 bytes
protected int charOctetLength;
protected boolean isAllowNull;
protected boolean isAutoincrement;
protected String defaultValue;
}
protected abstract Type jdbcTypeToDoris(JdbcFieldSchema fieldSchema);

View File

@ -17,19 +17,20 @@
package org.apache.doris.external.jdbc;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.util.Util;
import avro.shaded.com.google.common.collect.Lists;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
@ -91,7 +92,7 @@ public class JdbcMySQLClient extends JdbcClient {
private Map<String, String> getJdbcColumnsTypeInfo(String dbName, String tableName) {
Connection conn = getConnection();
ResultSet resultSet = null;
Map<String, String> fieldtoType = new HashMap<String, String>();
Map<String, String> fieldtoType = Maps.newHashMap();
StringBuilder queryBuf = new StringBuilder("SHOW FULL COLUMNS FROM ");
queryBuf.append(tableName);
@ -135,6 +136,7 @@ public class JdbcMySQLClient extends JdbcClient {
String catalogName = getCatalogName(conn);
tableName = modifyTableNameIfNecessary(tableName);
rs = getColumns(databaseMetaData, catalogName, dbName, tableName);
List<String> primaryKeys = getPrimaryKeys(dbName, tableName);
boolean needGetDorisColumns = true;
Map<String, String> mapFieldtoType = null;
while (rs.next()) {
@ -159,6 +161,7 @@ public class JdbcMySQLClient extends JdbcClient {
}
}
field.setKey(primaryKeys.contains(field.getColumnName()));
field.setColumnSize(rs.getInt("COLUMN_SIZE"));
field.setDecimalDigits(rs.getInt("DECIMAL_DIGITS"));
field.setNumPrecRadix(rs.getInt("NUM_PREC_RADIX"));
@ -171,6 +174,9 @@ public class JdbcMySQLClient extends JdbcClient {
field.setAllowNull(rs.getInt("NULLABLE") != 0);
field.setRemarks(rs.getString("REMARKS"));
field.setCharOctetLength(rs.getInt("CHAR_OCTET_LENGTH"));
String isAutoincrement = rs.getString("IS_AUTOINCREMENT");
field.setAutoincrement("YES".equalsIgnoreCase(isAutoincrement));
field.setDefaultValue(rs.getString("COLUMN_DEF"));
tableSchema.add(field);
}
} catch (SQLException e) {
@ -182,6 +188,41 @@ public class JdbcMySQLClient extends JdbcClient {
return tableSchema;
}
@Override
public List<Column> getColumnsFromJdbc(String dbName, String tableName) {
List<JdbcFieldSchema> jdbcTableSchema = getJdbcColumnsInfo(dbName, tableName);
List<Column> dorisTableSchema = Lists.newArrayListWithCapacity(jdbcTableSchema.size());
for (JdbcFieldSchema field : jdbcTableSchema) {
dorisTableSchema.add(new Column(field.getColumnName(),
jdbcTypeToDoris(field), field.isKey(), null,
field.isAllowNull(), field.isAutoincrement(), field.getDefaultValue(), field.getRemarks(),
true, null, -1, null,
null, null, null));
}
return dorisTableSchema;
}
@Override
protected List<String> getPrimaryKeys(String dbName, String tableName) {
List<String> primaryKeys = Lists.newArrayList();
Connection conn = null;
ResultSet rs = null;
try {
conn = getConnection();
DatabaseMetaData databaseMetaData = conn.getMetaData();
rs = databaseMetaData.getPrimaryKeys(dbName, null, tableName);
while (rs.next()) {
String columnName = rs.getString("COLUMN_NAME");
primaryKeys.add(columnName);
}
} catch (SQLException e) {
throw new JdbcClientException("Failed to get primary keys for table", e);
} finally {
close(rs, conn);
}
return primaryKeys;
}
@Override
protected Type jdbcTypeToDoris(JdbcFieldSchema fieldSchema) {
// For mysql type: "INT UNSIGNED":

View File

@ -20,7 +20,6 @@
package org.apache.doris.planner;
import org.apache.doris.catalog.JdbcTable;
import org.apache.doris.catalog.MysqlTable;
import org.apache.doris.catalog.OdbcTable;
import org.apache.doris.catalog.Table;
@ -66,8 +65,6 @@ public abstract class DataSink {
return new MysqlTableSink((MysqlTable) table);
} else if (table instanceof OdbcTable) {
return new OdbcTableSink((OdbcTable) table);
} else if (table instanceof JdbcTable) {
return new JdbcTableSink((JdbcTable) table);
} else {
throw new AnalysisException("Unknown table type " + table.getType());
}

View File

@ -29,6 +29,8 @@ import org.apache.doris.thrift.TOdbcTableType;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.List;
public class JdbcTableSink extends DataSink {
private static final Logger LOG = LogManager.getLogger(JdbcTableSink.class);
@ -45,7 +47,7 @@ public class JdbcTableSink extends DataSink {
private final boolean useTransaction;
private String insertSql;
public JdbcTableSink(JdbcTable jdbcTable) {
public JdbcTableSink(JdbcTable jdbcTable, List<String> insertCols) {
resourceName = jdbcTable.getResourceName();
jdbcType = jdbcTable.getJdbcTableType();
externalTableName = JdbcTable.databaseProperName(jdbcType, jdbcTable.getExternalTableName());
@ -57,7 +59,7 @@ public class JdbcTableSink extends DataSink {
driverUrl = jdbcTable.getDriverUrl();
checkSum = jdbcTable.getCheckSum();
dorisTableName = jdbcTable.getName();
insertSql = jdbcTable.getInsertSql();
insertSql = jdbcTable.getInsertSql(insertCols);
}
@Override

View File

@ -246,6 +246,9 @@ VIEWS
VIEW_ROUTINE_USAGE
VIEW_TABLE_USAGE
-- !auto_default_t --
0
-- !test_insert1 --
doris1 18

View File

@ -48,6 +48,7 @@ suite("test_mysql_jdbc_catalog", "p0") {
String ex_tb20 = "ex_tb20";
String test_insert = "test_insert";
String test_insert2 = "test_insert2";
String auto_default_t = "auto_default_t";
sql """drop catalog if exists ${catalog_name} """
@ -99,6 +100,7 @@ suite("test_mysql_jdbc_catalog", "p0") {
order_qt_ex_tb19 """ select * from ${ex_tb19} order by date_value; """
order_qt_ex_tb20 """ select * from ${ex_tb20} order by decimal_normal; """
order_qt_information_schema """ show tables from information_schema; """
order_qt_auto_default_t """insert into ${auto_default_t}(name) values('a'); """
// test insert
String uuid1 = UUID.randomUUID().toString();