[fix](Nereids) support complex literal cast in fe (#29599)
This commit is contained in:
@ -46,11 +46,26 @@ public class StructLiteral extends LiteralExpr {
|
||||
public StructLiteral(LiteralExpr... exprs) throws AnalysisException {
|
||||
type = new StructType();
|
||||
children = new ArrayList<>();
|
||||
for (int i = 0; i < exprs.length; i++) {
|
||||
if (!StructType.STRUCT.supportSubType(exprs[i].getType())) {
|
||||
throw new AnalysisException("Invalid element type in STRUCT: " + exprs[i].getType());
|
||||
}
|
||||
((StructType) type).addField(
|
||||
new StructField(StructField.DEFAULT_FIELD_NAME + (i + 1), exprs[i].getType()));
|
||||
children.add(exprs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* for nereids
|
||||
*/
|
||||
public StructLiteral(Type type, LiteralExpr... exprs) throws AnalysisException {
|
||||
this.type = type;
|
||||
this.children = new ArrayList<>();
|
||||
for (LiteralExpr expr : exprs) {
|
||||
if (!StructType.STRUCT.supportSubType(expr.getType())) {
|
||||
throw new AnalysisException("Invalid element type in STRUCT: " + expr.getType());
|
||||
}
|
||||
((StructType) type).addField(new StructField(expr.getType()));
|
||||
children.add(expr);
|
||||
}
|
||||
}
|
||||
@ -104,8 +119,8 @@ public class StructLiteral extends LiteralExpr {
|
||||
// same with be default field index start with 1
|
||||
for (int i = 0; i < children.size(); i++) {
|
||||
Expr child = children.get(i);
|
||||
String fieldName = new StructField(child.getType()).getName();
|
||||
list.add("\"" + fieldName + (i + 1) + "\": " + getStringLiteralForComplexType(child));
|
||||
list.add("\"" + ((StructType) type).getFields().get(i).getName() + "\": "
|
||||
+ getStringLiteralForComplexType(child));
|
||||
}
|
||||
return "{" + StringUtils.join(list, ", ") + "}";
|
||||
}
|
||||
|
||||
@ -92,7 +92,7 @@ public class Cast extends Expression implements UnaryExpression {
|
||||
|
||||
@Override
|
||||
public String toSql() throws UnboundException {
|
||||
return "cast(" + child().toSql() + " as " + targetType + ")";
|
||||
return "cast(" + child().toSql() + " as " + targetType.toSql() + ")";
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -22,9 +22,9 @@ import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.StructLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.StructField;
|
||||
import org.apache.doris.nereids.types.StructType;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -66,11 +66,7 @@ public class CreateStruct extends ScalarFunction
|
||||
if (arity() == 0) {
|
||||
return SIGNATURES;
|
||||
} else {
|
||||
ImmutableList.Builder<StructField> structFields = ImmutableList.builder();
|
||||
for (int i = 0; i < arity(); i++) {
|
||||
structFields.add(new StructField(String.valueOf(i + 1), children.get(i).getDataType(), true, ""));
|
||||
}
|
||||
return ImmutableList.of(FunctionSignature.ret(new StructType(structFields.build()))
|
||||
return ImmutableList.of(FunctionSignature.ret(StructLiteral.computeDataType(children))
|
||||
.args(children.stream().map(ExpressionTrait::getDataType).toArray(DataType[]::new)));
|
||||
}
|
||||
}
|
||||
|
||||
@ -25,6 +25,7 @@ import org.apache.doris.nereids.types.ArrayType;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.NullType;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@ -43,8 +44,7 @@ public class ArrayLiteral extends Literal {
|
||||
* construct array literal
|
||||
*/
|
||||
public ArrayLiteral(List<Literal> items) {
|
||||
super(ArrayType.of(CollectionUtils.isEmpty(items) ? NullType.INSTANCE : items.get(0).getDataType()));
|
||||
this.items = ImmutableList.copyOf(Objects.requireNonNull(items, "items should not null"));
|
||||
this(items, ArrayType.of(CollectionUtils.isEmpty(items) ? NullType.INSTANCE : items.get(0).getDataType()));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -52,6 +52,8 @@ public class ArrayLiteral extends Literal {
|
||||
*/
|
||||
public ArrayLiteral(List<Literal> items, DataType dataType) {
|
||||
super(dataType);
|
||||
Preconditions.checkArgument(dataType instanceof ArrayType,
|
||||
"dataType should be ArrayType, but we meet %s", dataType);
|
||||
this.items = ImmutableList.copyOf(Objects.requireNonNull(items, "items should not null"));
|
||||
}
|
||||
|
||||
|
||||
@ -18,6 +18,8 @@
|
||||
package org.apache.doris.nereids.trees.expressions.literal;
|
||||
|
||||
import org.apache.doris.analysis.LiteralExpr;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.MapType;
|
||||
@ -43,9 +45,15 @@ public class MapLiteral extends Literal {
|
||||
}
|
||||
|
||||
public MapLiteral(List<Literal> keys, List<Literal> values) {
|
||||
super(computeDataType(keys, values));
|
||||
this(keys, values, computeDataType(keys, values));
|
||||
}
|
||||
|
||||
private MapLiteral(List<Literal> keys, List<Literal> values, DataType dataType) {
|
||||
super(dataType);
|
||||
this.keys = ImmutableList.copyOf(Objects.requireNonNull(keys, "keys should not be null"));
|
||||
this.values = ImmutableList.copyOf(Objects.requireNonNull(values, "values should not be null"));
|
||||
Preconditions.checkArgument(dataType instanceof MapType,
|
||||
"dataType should be MapType, but we meet %s", dataType);
|
||||
Preconditions.checkArgument(keys.size() == values.size(),
|
||||
"key size %s is not equal to value size %s", keys.size(), values.size());
|
||||
}
|
||||
@ -55,6 +63,28 @@ public class MapLiteral extends Literal {
|
||||
return ImmutableList.of(keys, values);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException {
|
||||
if (this.dataType.equals(targetType)) {
|
||||
return this;
|
||||
} else if (targetType instanceof MapType) {
|
||||
// we should pass dataType to constructor because arguments maybe empty
|
||||
return new MapLiteral(
|
||||
keys.stream()
|
||||
.map(k -> k.uncheckedCastTo(((MapType) targetType).getKeyType()))
|
||||
.map(Literal.class::cast)
|
||||
.collect(ImmutableList.toImmutableList()),
|
||||
values.stream()
|
||||
.map(v -> v.uncheckedCastTo(((MapType) targetType).getValueType()))
|
||||
.map(Literal.class::cast)
|
||||
.collect(ImmutableList.toImmutableList()),
|
||||
targetType
|
||||
);
|
||||
} else {
|
||||
return super.uncheckedCastTo(targetType);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public LiteralExpr toLegacyLiteral() {
|
||||
List<LiteralExpr> keyExprs = keys.stream()
|
||||
|
||||
@ -19,15 +19,17 @@ package org.apache.doris.nereids.trees.expressions.literal;
|
||||
|
||||
import org.apache.doris.analysis.LiteralExpr;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.StructField;
|
||||
import org.apache.doris.nereids.types.StructType;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* struct literal
|
||||
@ -42,8 +44,17 @@ public class StructLiteral extends Literal {
|
||||
}
|
||||
|
||||
public StructLiteral(List<Literal> fields) {
|
||||
super(computeDataType(fields));
|
||||
this.fields = ImmutableList.copyOf(fields);
|
||||
this(fields, computeDataType(fields));
|
||||
}
|
||||
|
||||
private StructLiteral(List<Literal> fields, DataType dataType) {
|
||||
super(dataType);
|
||||
this.fields = ImmutableList.copyOf(Objects.requireNonNull(fields, "fields should not be null"));
|
||||
Preconditions.checkArgument(dataType instanceof StructType,
|
||||
"dataType should be StructType, but we meet %s", dataType);
|
||||
Preconditions.checkArgument(fields.size() == ((StructType) dataType).getFields().size(),
|
||||
"fields size is not same with dataType size. %s vs %s",
|
||||
fields.size(), ((StructType) dataType).getFields().size());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -51,10 +62,30 @@ public class StructLiteral extends Literal {
|
||||
return fields;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException {
|
||||
if (this.dataType.equals(targetType)) {
|
||||
return this;
|
||||
} else if (targetType instanceof StructType) {
|
||||
// we should pass dataType to constructor because arguments maybe empty
|
||||
if (((StructType) targetType).getFields().size() != this.fields.size()) {
|
||||
return super.uncheckedCastTo(targetType);
|
||||
}
|
||||
ImmutableList.Builder<Literal> newLiterals = ImmutableList.builder();
|
||||
for (int i = 0; i < fields.size(); i++) {
|
||||
newLiterals.add((Literal) fields.get(i)
|
||||
.uncheckedCastTo(((StructType) targetType).getFields().get(i).getDataType()));
|
||||
}
|
||||
return new StructLiteral(newLiterals.build(), targetType);
|
||||
} else {
|
||||
return super.uncheckedCastTo(targetType);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public LiteralExpr toLegacyLiteral() {
|
||||
try {
|
||||
return new org.apache.doris.analysis.StructLiteral(
|
||||
return new org.apache.doris.analysis.StructLiteral(dataType.toCatalogDataType(),
|
||||
fields.stream().map(Literal::toLegacyLiteral).toArray(LiteralExpr[]::new)
|
||||
);
|
||||
} catch (Exception e) {
|
||||
@ -89,7 +120,18 @@ public class StructLiteral extends Literal {
|
||||
|
||||
@Override
|
||||
public String toSql() {
|
||||
return "{" + fields.stream().map(Literal::toSql).collect(Collectors.joining(",")) + "}";
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("STRUCT(");
|
||||
for (int i = 0; i < fields.size(); i++) {
|
||||
if (i != 0) {
|
||||
sb.append(",");
|
||||
}
|
||||
sb.append("'").append(((StructType) dataType).getFields().get(i).getName()).append("'");
|
||||
sb.append(":");
|
||||
sb.append(fields.get(i).toSql());
|
||||
}
|
||||
sb.append(")");
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -97,10 +139,10 @@ public class StructLiteral extends Literal {
|
||||
return visitor.visitStructLiteral(this, context);
|
||||
}
|
||||
|
||||
private static StructType computeDataType(List<Literal> fields) {
|
||||
public static StructType computeDataType(List<? extends Expression> fields) {
|
||||
ImmutableList.Builder<StructField> structFields = ImmutableList.builder();
|
||||
for (int i = 0; i < fields.size(); i++) {
|
||||
structFields.add(new StructField(String.valueOf(i + 1), fields.get(i).getDataType(), true, ""));
|
||||
structFields.add(new StructField("col" + (i + 1), fields.get(i).getDataType(), true, ""));
|
||||
}
|
||||
return new StructType(structFields.build());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user