[Bug](agg-state) fix core dump on not nullable argument for aggstate's nested argument (#21331)

fix core dump on not nullable argument for aggstate's nested argument
This commit is contained in:
Pxl
2023-06-30 18:20:25 +08:00
committed by GitHub
parent b7d6a70868
commit 88cbea2b56
17 changed files with 115 additions and 46 deletions

View File

@ -260,7 +260,7 @@ public class ArithmeticExpr extends Expr {
if (children.size() == 1) {
return op.toString() + " " + getChild(0).toSql();
} else {
return getChild(0).toSql() + " " + op.toString() + " " + getChild(1).toSql();
return "(" + getChild(0).toSql() + " " + op.toString() + " " + getChild(1).toSql() + ")";
}
}

View File

@ -437,9 +437,7 @@ public class NativeInsertStmt extends InsertStmt {
mentionedColumns.add(col.getName());
targetColumns.add(col);
}
realTargetColumnNames = targetColumns.stream().map(column -> column.getName()).collect(Collectors.toList());
} else {
realTargetColumnNames = targetColumnNames;
for (String colName : targetColumnNames) {
Column col = targetTable.getColumn(colName);
if (col == null) {
@ -453,8 +451,8 @@ public class NativeInsertStmt extends InsertStmt {
// hll column mush in mentionedColumns
for (Column col : targetTable.getBaseSchema()) {
if (col.getType().isObjectStored() && !mentionedColumns.contains(col.getName())) {
throw new AnalysisException(" object-stored column " + col.getName()
+ " must in insert into columns");
throw new AnalysisException(
" object-stored column " + col.getName() + " mush in insert into columns");
}
}
}
@ -535,20 +533,30 @@ public class NativeInsertStmt extends InsertStmt {
}
// check if size of select item equal with columns mentioned in statement
if (mentionedColumns.size() != queryStmt.getResultExprs().size()
|| realTargetColumnNames.size() != queryStmt.getResultExprs().size()) {
if (mentionedColumns.size() != queryStmt.getResultExprs().size()) {
ErrorReport.reportAnalysisException(ErrorCode.ERR_WRONG_VALUE_COUNT);
}
// Check if all columns mentioned is enough
checkColumnCoverage(mentionedColumns, targetTable.getBaseSchema());
realTargetColumnNames = targetColumns.stream().map(column -> column.getName()).collect(Collectors.toList());
Map<String, Expr> slotToIndex = Maps.newTreeMap(String.CASE_INSENSITIVE_ORDER);
for (int i = 0; i < realTargetColumnNames.size(); i++) {
for (int i = 0; i < queryStmt.getResultExprs().size(); i++) {
slotToIndex.put(realTargetColumnNames.get(i), queryStmt.getResultExprs().get(i)
.checkTypeCompatibility(targetTable.getColumn(realTargetColumnNames.get(i)).getType()));
}
for (Column column : targetTable.getBaseSchema()) {
if (!slotToIndex.containsKey(column.getName())) {
if (column.getDefaultValue() == null) {
slotToIndex.put(column.getName(), new NullLiteral());
} else {
slotToIndex.put(column.getName(), new StringLiteral(column.getDefaultValue()));
}
}
}
// handle VALUES() or SELECT constant list
if (isValuesOrConstantSelect) {
SelectStmt selectStmt = (SelectStmt) queryStmt;

View File

@ -193,9 +193,6 @@ public class MaterializedIndexMeta implements Writable, GsonPostProcessable {
// mv_count_sale_amt -> mva_SUM__CASE WHEN `sale_amt` IS NULL THEN 0 ELSE 1 END
List<SlotRef> slots = new ArrayList<>();
entry.getValue().collect(SlotRef.class, slots);
if (slots.size() > 1) {
throw new IOException("DefineExpr have multiple slot in MaterializedIndex, Expr=" + entry.getKey());
}
String name = MaterializedIndexMeta.normalizeName(slots.get(0).toSqlWithoutTbl());
Column matchedColumn = null;

View File

@ -18,15 +18,12 @@
package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator;
import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator;
import org.apache.doris.nereids.trees.expressions.functions.combinator.UnionCombinator;
import org.apache.doris.nereids.types.AggStateType;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
@ -66,9 +63,7 @@ public class AggStateFunctionBuilder extends FunctionBuilder {
return false;
}
return nestedBuilder.canApply(((AggStateType) argument.getDataType()).getSubTypes().stream().map(t -> {
return new SlotReference("mocked", t);
}).collect(ImmutableList.toImmutableList()));
return nestedBuilder.canApply(((AggStateType) argument.getDataType()).getMockedExpressions());
}
}
@ -95,11 +90,7 @@ public class AggStateFunctionBuilder extends FunctionBuilder {
Expression arg = (Expression) arguments.get(0);
AggStateType type = (AggStateType) arg.getDataType();
List<Expression> nestedArgumens = type.getSubTypes().stream().map(t -> {
return new SlotReference("mocked", t);
}).collect(Collectors.toList());
return (AggregateFunction) nestedBuilder.build(nestedName, nestedArgumens);
return (AggregateFunction) nestedBuilder.build(nestedName, type.getMockedExpressions());
}
@Override

View File

@ -32,7 +32,6 @@ import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* AggState combinator merge
@ -47,13 +46,12 @@ public class MergeCombinator extends AggregateFunction
super(nested.getName() + AggStateFunctionBuilder.MERGE_SUFFIX, arguments);
this.nested = Objects.requireNonNull(nested, "nested can not be null");
inputType = new AggStateType(nested.getName(), nested.getArgumentsTypes(),
nested.getArguments().stream().map(Expression::nullable).collect(Collectors.toList()));
inputType = (AggStateType) arguments.get(0).getDataType();
}
@Override
public MergeCombinator withChildren(List<Expression> children) {
return new MergeCombinator(children, nested.withChildren(children));
return new MergeCombinator(children, nested);
}
@Override

View File

@ -59,7 +59,7 @@ public class StateCombinator extends ScalarFunction
@Override
public StateCombinator withChildren(List<Expression> children) {
return new StateCombinator(children, nested.withChildren(children));
return new StateCombinator(children, nested);
}
@Override

View File

@ -32,7 +32,6 @@ import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* AggState combinator union
@ -47,13 +46,12 @@ public class UnionCombinator extends AggregateFunction
super(nested.getName() + AggStateFunctionBuilder.UNION_SUFFIX, arguments);
this.nested = Objects.requireNonNull(nested, "nested can not be null");
inputType = new AggStateType(nested.getName(), nested.getArgumentsTypes(),
nested.getArguments().stream().map(Expression::nullable).collect(Collectors.toList()));
inputType = (AggStateType) arguments.get(0).getDataType();
}
@Override
public UnionCombinator withChildren(List<Expression> children) {
return new UnionCombinator(children, nested.withChildren(children));
return new UnionCombinator(children, nested);
}
@Override

View File

@ -19,11 +19,14 @@ package org.apache.doris.nereids.types;
import org.apache.doris.analysis.Expr;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
@ -56,8 +59,12 @@ public class AggStateType extends DataType {
this.functionName = functionName;
}
public List<DataType> getSubTypes() {
return subTypes;
public List<Expression> getMockedExpressions() {
List<Expression> result = new ArrayList<Expression>();
for (int i = 0; i < subTypes.size(); i++) {
result.add(new SlotReference("mocked", subTypes.get(i), subTypeNullables.get(i)));
}
return result;
}
@Override

View File

@ -264,7 +264,7 @@ public class ExprTest {
DataInputStream dis = new DataInputStream(new FileInputStream(file));
Expr readExpr = Expr.readIn(dis);
Assert.assertTrue(readExpr instanceof ArithmeticExpr);
Assert.assertEquals("cos(1) + 100 / 200", readExpr.toSql());
Assert.assertEquals("(cos(1) + (100 / 200))", readExpr.toSql());
// 3. delete files
dis.close();

View File

@ -150,13 +150,13 @@ public class LoadStmtTest {
columnDescs.descs = columns1;
columnDescs.isColumnDescsRewrited = false;
Load.rewriteColumns(columnDescs);
String orig = "`c1` + 1 + 1";
String orig = "((`c1` + 1) + 1)";
Assert.assertEquals(orig, columns1.get(4).getExpr().toString());
List<ImportColumnDesc> columns2 = getColumns("c1,c2,c3,tmp_c5 = tmp_c4+1, tmp_c4=c1 + 1");
columnDescs.descs = columns2;
columnDescs.isColumnDescsRewrited = false;
String orig2 = "`tmp_c4` + 1";
String orig2 = "(`tmp_c4` + 1)";
Load.rewriteColumns(columnDescs);
Assert.assertEquals(orig2, columns2.get(3).getExpr().toString());

View File

@ -180,7 +180,7 @@ public class S3TvfLoadStmtTest {
Deencapsulation.invoke(s3TvfLoadStmt, "rewriteExpr", columnsDescList);
Assert.assertEquals(columnsDescList.size(), 5);
final String orig4 = "upper(`c1`) + 1 + 1";
final String orig4 = "((upper(`c1`) + 1) + 1)";
Assert.assertEquals(orig4, columnsDescList.get(4).getExpr().toString());
final List<ImportColumnDesc> filterColumns = Deencapsulation.invoke(s3TvfLoadStmt,

View File

@ -67,9 +67,9 @@ public class RepeatNodeTest extends TestWithFeService {
String sql2 = "select /*+ SET_VAR(enable_nereids_planner=false) */ (id + 1) id_, name, sum(cost) from mycost group by grouping sets((id_, name),());";
String explainString2 = getSQLPlanOrErrorMsg("explain " + sql2);
System.out.println(explainString2);
Assertions.assertTrue(explainString2.contains("exprs: (`id` + 1), `name`, `cost`"));
Assertions.assertTrue(explainString2.contains("exprs: ((`id` + 1)), `name`, `cost`"));
Assertions.assertTrue(
explainString2.contains(" output slots: `(`id` + 1)`, ``name``, ``cost``, ``GROUPING_ID``"));
explainString2.contains(" output slots: `((`id` + 1))`, ``name``, ``cost``, ``GROUPING_ID``"));
String sql3 = "select /*+ SET_VAR(enable_nereids_planner=false) */ 1 as id_, name, sum(cost) from mycost group by grouping sets((id_, name),());";
String explainString3 = getSQLPlanOrErrorMsg("explain " + sql3);