Improve the processing logic of Load statement derived columns (#5140)

* support transitive column references in load expressions
This commit is contained in:
Zhengguo Yang
2020-12-30 10:27:46 +08:00
committed by GitHub
parent cd865c95e0
commit 62604dfeac
9 changed files with 108 additions and 29 deletions

View File

@ -51,6 +51,10 @@ public class ImportColumnDesc {
return expr;
}
// Replaces this column's mapping expression; used when rewriting derived
// columns so that an expression only references original source columns.
public void setExpr(Expr expr) {
this.expr = expr;
}
// A plain source column carries no mapping expression; a derived (mapping)
// column always has one. Callers use this to tell the two apart.
public boolean isColumn() {
return expr == null;
}

View File

@ -107,6 +107,7 @@ import org.apache.doris.transaction.TransactionState.LoadJobSourceType;
import org.apache.doris.transaction.TransactionState.TxnCoordinator;
import org.apache.doris.transaction.TransactionState.TxnSourceType;
import org.apache.doris.transaction.TransactionStatus;
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
@ -115,6 +116,7 @@ import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.gson.Gson;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@ -941,6 +943,7 @@ public class Load {
Map<String, Pair<String, List<String>>> columnToHadoopFunction,
Map<String, Expr> exprsByName, Analyzer analyzer, TupleDescriptor srcTupleDesc,
Map<String, SlotDescriptor> slotDescByName, TBrokerScanRangeParams params) throws UserException {
rewriteColumns(columnExprs);
initColumns(tbl, columnExprs, columnToHadoopFunction, exprsByName, analyzer,
srcTupleDesc, slotDescByName, params, true);
}
@ -958,22 +961,16 @@ public class Load {
Map<String, Expr> exprsByName, Analyzer analyzer, TupleDescriptor srcTupleDesc,
Map<String, SlotDescriptor> slotDescByName, TBrokerScanRangeParams params,
boolean needInitSlotAndAnalyzeExprs) throws UserException {
// check mapping column exist in schema
// !! all column mappings are in columnExprs !!
for (ImportColumnDesc importColumnDesc : columnExprs) {
if (importColumnDesc.isColumn()) {
continue;
}
String mappingColumnName = importColumnDesc.getColumnName();
if (tbl.getColumn(mappingColumnName) == null) {
throw new DdlException("Mapping column is not in table. column: " + mappingColumnName);
}
}
// We make a copy of the columnExprs so that our subsequent changes
// to the columnExprs will not affect the original columnExprs.
List<ImportColumnDesc> copiedColumnExprs = Lists.newArrayList(columnExprs);
// skip the mapping columns not exist in schema
List<ImportColumnDesc> copiedColumnExprs = new ArrayList<>();
for (ImportColumnDesc importColumnDesc : columnExprs) {
String mappingColumnName = importColumnDesc.getColumnName();
if (importColumnDesc.isColumn() || tbl.getColumn(mappingColumnName) != null) {
copiedColumnExprs.add(importColumnDesc);
}
}
// check whether the OlapTable has sequenceCol
boolean hasSequenceCol = false;
if (tbl instanceof OlapTable && ((OlapTable)tbl).hasSequenceCol()) {
@ -1133,6 +1130,44 @@ public class Load {
LOG.debug("after init column, exprMap: {}", exprsByName);
}
/**
 * Rewrites derivative column mappings in place so that every mapping
 * expression only references original source columns, by substituting the
 * (already rewritten) expressions of earlier derived columns.
 * e.g. (v1, v2 = v1 + 1, v3 = v2 + 1) becomes (v1, v2 = v1 + 1, v3 = v1 + 1 + 1).
 */
public static void rewriteColumns(List<ImportColumnDesc> columnExprs) {
    // derived column name -> its fully rewritten expression
    Map<String, Expr> derivativeColumns = new HashMap<>();
    for (ImportColumnDesc desc : columnExprs) {
        if (desc.isColumn()) {
            // plain source columns need no rewriting and are never recorded
            continue;
        }
        Expr mappingExpr = desc.getExpr();
        if (mappingExpr instanceof SlotRef) {
            // the whole expression is a single column reference; swap it out
            // directly when it names an earlier derived column
            String referenced = ((SlotRef) mappingExpr).getColumnName();
            if (derivativeColumns.containsKey(referenced)) {
                desc.setExpr(derivativeColumns.get(referenced));
            }
        } else {
            // substitute derived-column references inside the expression tree
            recursiveRewrite(mappingExpr, derivativeColumns);
        }
        // record after rewriting so later columns see the expanded form
        derivativeColumns.put(desc.getColumnName(), desc.getExpr());
    }
}
/**
 * Walks the expression tree rooted at {@code expr} and replaces every
 * SlotRef child that names a derived column with that column's defining
 * expression. Non-SlotRef children are descended into recursively.
 */
private static void recursiveRewrite(Expr expr, Map<String, Expr> derivativeColumns) {
    if (CollectionUtils.isEmpty(expr.getChildren())) {
        // leaf node: nothing to substitute
        return;
    }
    for (int idx = 0; idx < expr.getChildren().size(); idx++) {
        Expr child = expr.getChild(idx);
        if (!(child instanceof SlotRef)) {
            recursiveRewrite(child, derivativeColumns);
            continue;
        }
        String referenced = ((SlotRef) child).getColumnName();
        if (derivativeColumns.containsKey(referenced)) {
            // splice in the derived column's expression at this position
            expr.setChild(idx, derivativeColumns.get(referenced));
        }
    }
}
/**
* This method is used to transform hadoop function.
* The hadoop function includes: replace_value, strftime, time_format, alignment_timestamp, default_value, now.

View File

@ -23,7 +23,9 @@ import org.apache.doris.catalog.SparkResource;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.Config;
import org.apache.doris.common.UserException;
import org.apache.doris.common.util.SqlParserUtils;
import org.apache.doris.load.EtlJobType;
import org.apache.doris.load.Load;
import org.apache.doris.load.loadv2.LoadTask;
import org.apache.doris.mysql.privilege.PaloAuth;
import org.apache.doris.mysql.privilege.PrivPredicate;
@ -35,7 +37,11 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import mockit.Expectations;
import mockit.Injectable;
@ -134,4 +140,31 @@ public class LoadStmtTest {
Assert.fail("No exception throws.");
}
@Test
public void testRewrite() throws Exception {
    // transitive reference: tmp_c5 uses tmp_c4, which itself derives from c1
    List<ImportColumnDesc> transitive = getColumns("c1,c2,c3,tmp_c4=c1 + 1, tmp_c5 = tmp_c4+1");
    Load.rewriteColumns(transitive);
    Assert.assertEquals("`c1` + 1 + 1", transitive.get(4).getExpr().toString());

    // forward reference: tmp_c4 is defined after its use, so no substitution
    List<ImportColumnDesc> forwardRef = getColumns("c1,c2,c3,tmp_c5 = tmp_c4+1, tmp_c4=c1 + 1");
    Load.rewriteColumns(forwardRef);
    Assert.assertEquals("`tmp_c4` + 1", forwardRef.get(3).getExpr().toString());

    // plain columns without mapping expressions are left untouched
    List<ImportColumnDesc> plain = getColumns("c1,c2,c3");
    Load.rewriteColumns(plain);
    Assert.assertEquals("c3", plain.get(2).toString());
}
/**
 * Parses a raw column mapping list (the body of a COLUMNS(...) clause)
 * into ImportColumnDesc objects via the SQL parser.
 */
private List<ImportColumnDesc> getColumns(String columns) throws Exception {
    String columnsSQL = "COLUMNS (" + columns + ")";
    SqlParser parser = new SqlParser(new SqlScanner(new StringReader(columnsSQL)));
    ImportColumnsStmt stmt = (ImportColumnsStmt) SqlParserUtils.getFirstStmt(parser);
    return stmt.getColumns();
}
}

View File

@ -674,17 +674,6 @@ public class StreamLoadScanNodeTest {
scanNode.toThrift(planNode);
}
// NOTE(review): this test is removed by the commit — the new behavior skips
// mapping columns that are absent from the table schema instead of throwing.
// It previously asserted that initColumns raises DdlException when a mapping
// column ("c3") does not exist in the target table.
@Test(expected = DdlException.class)
public void testLoadInitColumnsMappingColumnNotExist() throws UserException {
List<Column> columns = Lists.newArrayList();
columns.add(new Column("c1", Type.INT, true, null, false, null, ""));
columns.add(new Column("c2", ScalarType.createVarchar(10), true, null, false, null, ""));
Table table = new Table(1L, "table0", TableType.OLAP, columns);
List<ImportColumnDesc> columnExprs = Lists.newArrayList();
columnExprs.add(new ImportColumnDesc("c3", new FunctionCallExpr("func", Lists.newArrayList())));
Load.initColumns(table, columnExprs, null, null, null, null, null, null);
}
@Test
public void testSequenceColumnWithSetColumns() throws UserException {
Analyzer analyzer = new Analyzer(catalog, connectContext);