[fix](Nereids) support complex literal cast in fe (#29599)

This commit is contained in:
morrySnow
2024-01-11 11:15:45 +08:00
committed by yiguolei
parent 771c66c034
commit e93a16ac6e
11 changed files with 132 additions and 45 deletions

View File

@ -46,11 +46,26 @@ public class StructLiteral extends LiteralExpr {
public StructLiteral(LiteralExpr... exprs) throws AnalysisException {
type = new StructType();
children = new ArrayList<>();
for (int i = 0; i < exprs.length; i++) {
if (!StructType.STRUCT.supportSubType(exprs[i].getType())) {
throw new AnalysisException("Invalid element type in STRUCT: " + exprs[i].getType());
}
((StructType) type).addField(
new StructField(StructField.DEFAULT_FIELD_NAME + (i + 1), exprs[i].getType()));
children.add(exprs[i]);
}
}
/**
* for nereids
*/
public StructLiteral(Type type, LiteralExpr... exprs) throws AnalysisException {
this.type = type;
this.children = new ArrayList<>();
for (LiteralExpr expr : exprs) {
if (!StructType.STRUCT.supportSubType(expr.getType())) {
throw new AnalysisException("Invalid element type in STRUCT: " + expr.getType());
}
((StructType) type).addField(new StructField(expr.getType()));
children.add(expr);
}
}
@ -104,8 +119,8 @@ public class StructLiteral extends LiteralExpr {
// same with be default field index start with 1
for (int i = 0; i < children.size(); i++) {
Expr child = children.get(i);
String fieldName = new StructField(child.getType()).getName();
list.add("\"" + fieldName + (i + 1) + "\": " + getStringLiteralForComplexType(child));
list.add("\"" + ((StructType) type).getFields().get(i).getName() + "\": "
+ getStringLiteralForComplexType(child));
}
return "{" + StringUtils.join(list, ", ") + "}";
}

View File

@ -92,7 +92,7 @@ public class Cast extends Expression implements UnaryExpression {
@Override
public String toSql() throws UnboundException {
return "cast(" + child().toSql() + " as " + targetType + ")";
return "cast(" + child().toSql() + " as " + targetType.toSql() + ")";
}
@Override

View File

@ -22,9 +22,9 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.expressions.literal.StructLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.StructField;
import org.apache.doris.nereids.types.StructType;
import com.google.common.collect.ImmutableList;
@ -66,11 +66,7 @@ public class CreateStruct extends ScalarFunction
if (arity() == 0) {
return SIGNATURES;
} else {
ImmutableList.Builder<StructField> structFields = ImmutableList.builder();
for (int i = 0; i < arity(); i++) {
structFields.add(new StructField(String.valueOf(i + 1), children.get(i).getDataType(), true, ""));
}
return ImmutableList.of(FunctionSignature.ret(new StructType(structFields.build()))
return ImmutableList.of(FunctionSignature.ret(StructLiteral.computeDataType(children))
.args(children.stream().map(ExpressionTrait::getDataType).toArray(DataType[]::new)));
}
}

View File

@ -25,6 +25,7 @@ import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.NullType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import org.springframework.util.CollectionUtils;
@ -43,8 +44,7 @@ public class ArrayLiteral extends Literal {
* construct array literal
*/
public ArrayLiteral(List<Literal> items) {
super(ArrayType.of(CollectionUtils.isEmpty(items) ? NullType.INSTANCE : items.get(0).getDataType()));
this.items = ImmutableList.copyOf(Objects.requireNonNull(items, "items should not null"));
this(items, ArrayType.of(CollectionUtils.isEmpty(items) ? NullType.INSTANCE : items.get(0).getDataType()));
}
/**
@ -52,6 +52,8 @@ public class ArrayLiteral extends Literal {
*/
public ArrayLiteral(List<Literal> items, DataType dataType) {
super(dataType);
Preconditions.checkArgument(dataType instanceof ArrayType,
"dataType should be ArrayType, but we meet %s", dataType);
this.items = ImmutableList.copyOf(Objects.requireNonNull(items, "items should not null"));
}

View File

@ -18,6 +18,8 @@
package org.apache.doris.nereids.trees.expressions.literal;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.MapType;
@ -43,9 +45,15 @@ public class MapLiteral extends Literal {
}
public MapLiteral(List<Literal> keys, List<Literal> values) {
super(computeDataType(keys, values));
this(keys, values, computeDataType(keys, values));
}
private MapLiteral(List<Literal> keys, List<Literal> values, DataType dataType) {
super(dataType);
this.keys = ImmutableList.copyOf(Objects.requireNonNull(keys, "keys should not be null"));
this.values = ImmutableList.copyOf(Objects.requireNonNull(values, "values should not be null"));
Preconditions.checkArgument(dataType instanceof MapType,
"dataType should be MapType, but we meet %s", dataType);
Preconditions.checkArgument(keys.size() == values.size(),
"key size %s is not equal to value size %s", keys.size(), values.size());
}
@ -55,6 +63,28 @@ public class MapLiteral extends Literal {
return ImmutableList.of(keys, values);
}
@Override
protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException {
if (this.dataType.equals(targetType)) {
return this;
} else if (targetType instanceof MapType) {
// we should pass dataType to constructor because arguments maybe empty
return new MapLiteral(
keys.stream()
.map(k -> k.uncheckedCastTo(((MapType) targetType).getKeyType()))
.map(Literal.class::cast)
.collect(ImmutableList.toImmutableList()),
values.stream()
.map(v -> v.uncheckedCastTo(((MapType) targetType).getValueType()))
.map(Literal.class::cast)
.collect(ImmutableList.toImmutableList()),
targetType
);
} else {
return super.uncheckedCastTo(targetType);
}
}
@Override
public LiteralExpr toLegacyLiteral() {
List<LiteralExpr> keyExprs = keys.stream()

View File

@ -19,15 +19,17 @@ package org.apache.doris.nereids.trees.expressions.literal;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.StructField;
import org.apache.doris.nereids.types.StructType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* struct literal
@ -42,8 +44,17 @@ public class StructLiteral extends Literal {
}
public StructLiteral(List<Literal> fields) {
super(computeDataType(fields));
this.fields = ImmutableList.copyOf(fields);
this(fields, computeDataType(fields));
}
private StructLiteral(List<Literal> fields, DataType dataType) {
super(dataType);
this.fields = ImmutableList.copyOf(Objects.requireNonNull(fields, "fields should not be null"));
Preconditions.checkArgument(dataType instanceof StructType,
"dataType should be StructType, but we meet %s", dataType);
Preconditions.checkArgument(fields.size() == ((StructType) dataType).getFields().size(),
"fields size is not same with dataType size. %s vs %s",
fields.size(), ((StructType) dataType).getFields().size());
}
@Override
@ -51,10 +62,30 @@ public class StructLiteral extends Literal {
return fields;
}
@Override
protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException {
if (this.dataType.equals(targetType)) {
return this;
} else if (targetType instanceof StructType) {
// we should pass dataType to constructor because arguments maybe empty
if (((StructType) targetType).getFields().size() != this.fields.size()) {
return super.uncheckedCastTo(targetType);
}
ImmutableList.Builder<Literal> newLiterals = ImmutableList.builder();
for (int i = 0; i < fields.size(); i++) {
newLiterals.add((Literal) fields.get(i)
.uncheckedCastTo(((StructType) targetType).getFields().get(i).getDataType()));
}
return new StructLiteral(newLiterals.build(), targetType);
} else {
return super.uncheckedCastTo(targetType);
}
}
@Override
public LiteralExpr toLegacyLiteral() {
try {
return new org.apache.doris.analysis.StructLiteral(
return new org.apache.doris.analysis.StructLiteral(dataType.toCatalogDataType(),
fields.stream().map(Literal::toLegacyLiteral).toArray(LiteralExpr[]::new)
);
} catch (Exception e) {
@ -89,7 +120,18 @@ public class StructLiteral extends Literal {
@Override
public String toSql() {
return "{" + fields.stream().map(Literal::toSql).collect(Collectors.joining(",")) + "}";
StringBuilder sb = new StringBuilder();
sb.append("STRUCT(");
for (int i = 0; i < fields.size(); i++) {
if (i != 0) {
sb.append(",");
}
sb.append("'").append(((StructType) dataType).getFields().get(i).getName()).append("'");
sb.append(":");
sb.append(fields.get(i).toSql());
}
sb.append(")");
return sb.toString();
}
@Override
@ -97,10 +139,10 @@ public class StructLiteral extends Literal {
return visitor.visitStructLiteral(this, context);
}
private static StructType computeDataType(List<Literal> fields) {
public static StructType computeDataType(List<? extends Expression> fields) {
ImmutableList.Builder<StructField> structFields = ImmutableList.builder();
for (int i = 0; i < fields.size(); i++) {
structFields.add(new StructField(String.valueOf(i + 1), fields.get(i).getDataType(), true, ""));
structFields.add(new StructField("col" + (i + 1), fields.get(i).getDataType(), true, ""));
}
return new StructType(structFields.build());
}