[fix](Nereids) support nested complex type literal (#25287)

This commit is contained in:
morrySnow
2023-10-12 14:17:38 +08:00
committed by GitHub
parent 2664d1cffb
commit a0d3206d78
15 changed files with 289 additions and 94 deletions

View File

@ -40,7 +40,7 @@ public class ArrayLiteral extends LiteralExpr {
children = new ArrayList<>();
}
public ArrayLiteral(Type type, LiteralExpr... exprs) throws AnalysisException {
public ArrayLiteral(Type type, LiteralExpr... exprs) {
this.type = type;
children = new ArrayList<>(Arrays.asList(exprs));
analysisDone();

View File

@ -224,8 +224,6 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.Array;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySlice;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Char;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ConvertTo;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateMap;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateStruct;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DayCeil;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DayFloor;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DaysAdd;
@ -264,6 +262,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.YearFloor;
import org.apache.doris.nereids.trees.expressions.functions.scalar.YearsAdd;
import org.apache.doris.nereids.trees.expressions.functions.scalar.YearsDiff;
import org.apache.doris.nereids.trees.expressions.functions.scalar.YearsSub;
import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
@ -276,9 +275,11 @@ import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Interval;
import org.apache.doris.nereids.trees.expressions.literal.LargeIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.MapLiteral;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StructLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.trees.plans.JoinHint;
@ -1759,22 +1760,49 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
return sb.toString();
}
@Override
public Object visitArrayLiteral(ArrayLiteralContext ctx) {
Literal[] items = ctx.items.stream().<Literal>map(this::typedVisit).toArray(Literal[]::new);
return new Array(items);
/**
 * Cast all items to one common type.
 *
 * <p>The common type is the first expected input type of an {@code Array} function built from
 * the items; every item is then cast to it with {@code checkedCastTo}.
 *
 * <p>An empty input is answered with a new empty list and no type inference at all: there is
 * nothing to unify, and reading {@code expectedInputTypes().get(0)} from a zero-argument
 * {@code Array} is not safe (visitArrayLiteral already avoids that call for empty arrays, while
 * visitMapLiteral passes empty key/value lists for the empty map literal).
 *
 * TODO remove this function after we refactor type coercion.
 *
 * @param items literals to unify, may be empty
 * @return a new list holding every item cast to the common type, in the original order
 */
private List<Literal> typeCoercionItems(List<Literal> items) {
    if (items.isEmpty()) {
        // no element means no type to derive; skip building the helper Array function
        return Lists.newArrayList();
    }
    DataType dataType = new Array(items.toArray(new Literal[0])).expectedInputTypes().get(0);
    return items.stream()
            .map(item -> item.checkedCastTo(dataType))
            .map(Literal.class::cast)
            .collect(Collectors.toList());
}
@Override
public Object visitMapLiteral(MapLiteralContext ctx) {
Literal[] items = ctx.items.stream().<Literal>map(this::typedVisit).toArray(Literal[]::new);
return new CreateMap(items);
public ArrayLiteral visitArrayLiteral(ArrayLiteralContext ctx) {
    // turn every parsed element of the array into a literal
    List<Literal> elements = ctx.items.stream()
            .<Literal>map(this::typedVisit)
            .collect(Collectors.toList());
    // an empty array has no element type to unify, so it is passed through without coercion
    List<Literal> arrayItems = elements.isEmpty() ? elements : typeCoercionItems(elements);
    return new ArrayLiteral(arrayItems);
}
@Override
public MapLiteral visitMapLiteral(MapLiteralContext ctx) {
    // the parser hands over the entries flattened as key, value, key, value, ...
    List<Literal> flattened = ctx.items.stream()
            .<Literal>map(this::typedVisit)
            .collect(Collectors.toList());
    if (flattened.size() % 2 != 0) {
        throw new ParseException("map can't be odd parameters, need even parameters", ctx);
    }
    List<Literal> keys = Lists.newArrayList();
    List<Literal> values = Lists.newArrayList();
    // the size is even at this point, so every step consumes one complete key/value pair
    for (int i = 0; i < flattened.size(); i += 2) {
        keys.add(flattened.get(i));
        values.add(flattened.get(i + 1));
    }
    // keys and values are unified separately: all keys share one type, all values another
    List<Literal> coercedKeys = typeCoercionItems(keys);
    List<Literal> coercedValues = typeCoercionItems(values);
    return new MapLiteral(coercedKeys, coercedValues);
}
@Override
public Object visitStructLiteral(StructLiteralContext ctx) {
Literal[] items = ctx.items.stream().<Literal>map(this::typedVisit).toArray(Literal[]::new);
return new CreateStruct(items);
List<Literal> fields = ctx.items.stream().<Literal>map(this::typedVisit).collect(Collectors.toList());
return new StructLiteral(fields);
}
@Override

View File

@ -507,7 +507,8 @@ public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule {
return checkedExpr.get();
}
List<Literal> arguments = (List) array.getArguments();
return new ArrayLiteral(arguments);
// we should pass dataType to constructor because arguments may be empty
return new ArrayLiteral(arguments, array.getDataType());
}
@Override

View File

@ -336,7 +336,7 @@ public class FunctionBinder extends AbstractExpressionRewriteRule {
public Expression visitCast(Cast cast, ExpressionRewriteContext context) {
cast = (Cast) super.visitCast(cast, context);
// NOTICE: just for compatibility with legacy planner.
if (cast.child().getDataType() instanceof ArrayType || cast.getDataType() instanceof ArrayType) {
if (cast.child().getDataType().isComplexType() || cast.getDataType().isComplexType()) {
TypeCoercionUtils.checkCanCastTo(cast.child().getDataType(), cast.getDataType());
}
return cast;

View File

@ -25,7 +25,6 @@ import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.collect.ImmutableList;
@ -76,7 +75,9 @@ public class Array extends ScalarFunction
.collect(Collectors.partitioningBy(TypeCoercionUtils::hasCharacterType));
List<DataType> needTypeCoercion = Lists.newArrayList(Sets.newHashSet(partitioned.get(true)));
if (needTypeCoercion.size() > 1 || !partitioned.get(false).isEmpty()) {
needTypeCoercion = Lists.newArrayList(StringType.INSTANCE);
needTypeCoercion = needTypeCoercion.stream()
.map(TypeCoercionUtils::replaceCharacterToString)
.collect(Collectors.toList());
}
needTypeCoercion.addAll(partitioned.get(false));
return needTypeCoercion.stream()

View File

@ -19,17 +19,22 @@ package org.apache.doris.nereids.trees.expressions.literal;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.NullType;
import com.google.common.collect.ImmutableList;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/** ArrayLiteral */
/**
* ArrayLiteral
*/
public class ArrayLiteral extends Literal {
private final List<Literal> items;
@ -38,15 +43,16 @@ public class ArrayLiteral extends Literal {
* construct array literal
*/
public ArrayLiteral(List<Literal> items) {
super(computeDataType(items));
this.items = items.stream()
.map(i -> {
if (i instanceof NullLiteral) {
DataType type = ((ArrayType) (this.getDataType())).getItemType();
return new NullLiteral(type);
}
return i;
}).collect(ImmutableList.toImmutableList());
super(ArrayType.of(CollectionUtils.isEmpty(items) ? NullType.INSTANCE : items.get(0).getDataType()));
this.items = ImmutableList.copyOf(Objects.requireNonNull(items, "items should not null"));
}
/**
 * construct array literal with an explicit data type.
 * when items is empty the data type cannot be taken from the items, so it is passed in.
 */
public ArrayLiteral(List<Literal> items, DataType dataType) {
    super(dataType);
    Objects.requireNonNull(items, "items should not null");
    this.items = ImmutableList.copyOf(items);
}
@Override
@ -56,17 +62,24 @@ public class ArrayLiteral extends Literal {
@Override
public LiteralExpr toLegacyLiteral() {
if (items.isEmpty()) {
return new org.apache.doris.analysis.ArrayLiteral();
LiteralExpr[] itemExprs = items.stream()
.map(Literal::toLegacyLiteral)
.toArray(LiteralExpr[]::new);
return new org.apache.doris.analysis.ArrayLiteral(getDataType().toCatalogDataType(), itemExprs);
}
@Override
protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException {
if (this.dataType.equals(targetType)) {
return this;
} else if (targetType instanceof ArrayType) {
// we should pass dataType to constructor because arguments may be empty
return new ArrayLiteral(items.stream()
.map(i -> i.uncheckedCastTo(((ArrayType) targetType).getItemType()))
.map(Literal.class::cast)
.collect(ImmutableList.toImmutableList()), targetType);
} else {
LiteralExpr[] itemExprs = items.stream()
.map(Literal::toLegacyLiteral)
.toArray(LiteralExpr[]::new);
try {
return new org.apache.doris.analysis.ArrayLiteral(getDataType().toCatalogDataType(), itemExprs);
} catch (Throwable t) {
throw new AnalysisException(t.getMessage(), t);
}
return super.uncheckedCastTo(targetType);
}
}
@ -90,17 +103,4 @@ public class ArrayLiteral extends Literal {
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitArrayLiteral(this, context);
}
private static DataType computeDataType(List<Literal> items) {
if (items.isEmpty()) {
return ArrayType.SYSTEM_DEFAULT;
}
DataType dataType = NullType.INSTANCE;
for (Literal item : items) {
if (!item.dataType.isNullType()) {
dataType = item.dataType;
}
}
return ArrayType.of(dataType);
}
}

View File

@ -197,6 +197,13 @@ public abstract class Literal extends Expression implements LeafExpression, Comp
@Override
protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException {
if (this.dataType.equals(targetType)) {
return this;
}
if (this instanceof NullLiteral) {
return new NullLiteral(targetType);
}
// TODO support string to complex
String desc = getStringValue();
if (targetType.isBooleanType()) {
if ("0".equals(desc) || "false".equals(desc.toLowerCase(Locale.ROOT))) {

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.expressions.literal;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.MapType;
@ -58,22 +57,13 @@ public class MapLiteral extends Literal {
@Override
public LiteralExpr toLegacyLiteral() {
if (keys.isEmpty()) {
return new org.apache.doris.analysis.MapLiteral();
} else {
List<LiteralExpr> keyExprs = keys.stream()
.map(Literal::toLegacyLiteral)
.collect(Collectors.toList());
List<LiteralExpr> valueExprs = values.stream()
.map(Literal::toLegacyLiteral)
.collect(Collectors.toList());
try {
return new org.apache.doris.analysis.MapLiteral(
getDataType().toCatalogDataType(), keyExprs, valueExprs);
} catch (Throwable t) {
throw new AnalysisException(t.getMessage(), t);
}
}
List<LiteralExpr> keyExprs = keys.stream()
.map(Literal::toLegacyLiteral)
.collect(Collectors.toList());
List<LiteralExpr> valueExprs = values.stream()
.map(Literal::toLegacyLiteral)
.collect(Collectors.toList());
return new org.apache.doris.analysis.MapLiteral(getDataType().toCatalogDataType(), keyExprs, valueExprs);
}
@Override

View File

@ -0,0 +1,107 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.literal;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.StructField;
import org.apache.doris.nereids.types.StructType;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
 * Literal of struct type, holding one literal per field in declaration order.
 */
public class StructLiteral extends Literal {

    private final List<Literal> fields;

    /**
     * Create a struct literal without any field, typed as {@code StructType.SYSTEM_DEFAULT}.
     */
    public StructLiteral() {
        super(StructType.SYSTEM_DEFAULT);
        this.fields = ImmutableList.of();
    }

    /**
     * Create a struct literal from its field literals; the struct type is derived from them.
     */
    public StructLiteral(List<Literal> fields) {
        super(computeDataType(fields));
        this.fields = ImmutableList.copyOf(fields);
    }

    @Override
    public List<Literal> getValue() {
        return fields;
    }

    @Override
    public LiteralExpr toLegacyLiteral() {
        try {
            // converting the children stays inside the try so their failures are wrapped as well
            LiteralExpr[] legacyFields = fields.stream()
                    .map(Literal::toLegacyLiteral)
                    .toArray(LiteralExpr[]::new);
            return new org.apache.doris.analysis.StructLiteral(legacyFields);
        } catch (Exception e) {
            throw new AnalysisException(e.getMessage(), e);
        }
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass() || !super.equals(o)) {
            return false;
        }
        return Objects.equals(fields, ((StructLiteral) o).fields);
    }

    @Override
    public int hashCode() {
        return Objects.hash(super.hashCode(), fields);
    }

    @Override
    public String toString() {
        return toSql();
    }

    @Override
    public String toSql() {
        // comma separated field literals wrapped in braces
        return fields.stream()
                .map(Literal::toSql)
                .collect(Collectors.joining(",", "{", "}"));
    }

    @Override
    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
        return visitor.visitStructLiteral(this, context);
    }

    /**
     * Build the struct type for the given field literals: each field is named by its 1-based
     * position ("1", "2", ...), takes the data type of its literal, is nullable and has an
     * empty comment.
     */
    private static StructType computeDataType(List<Literal> fields) {
        ImmutableList.Builder<StructField> builder = ImmutableList.builder();
        int position = 1;
        for (Literal field : fields) {
            builder.add(new StructField(String.valueOf(position), field.getDataType(), true, ""));
            position++;
        }
        return new StructType(builder.build());
    }
}

View File

@ -107,6 +107,7 @@ import org.apache.doris.nereids.trees.expressions.literal.MapLiteral;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StructLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
@ -309,6 +310,10 @@ public abstract class ExpressionVisitor<R, C>
return visitLiteral(mapLiteral, context);
}
/** visit a struct literal; this default implementation delegates to visitLiteral. */
public R visitStructLiteral(StructLiteral structLiteral, C context) {
return visitLiteral(structLiteral, context);
}
public R visitCompoundPredicate(CompoundPredicate compoundPredicate, C context) {
return visitBinaryOperator(compoundPredicate, context);
}

View File

@ -245,7 +245,7 @@ public abstract class LogicalSetOperation extends AbstractLogicalPlan implements
boolean nullable = leftFields.get(i).isNullable() || rightFields.get(i).isNullable();
DataType commonType = getAssignmentCompatibleType(
leftFields.get(i).getDataType(), rightFields.get(i).getDataType());
StructField commonField = leftFields.get(i).withDataTypeAndNulalble(commonType, nullable);
StructField commonField = leftFields.get(i).withDataTypeAndNullable(commonType, nullable);
commonFields.add(commonField);
}
return new StructType(commonFields.build());

View File

@ -19,12 +19,13 @@ package org.apache.doris.nereids.types;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.types.coercion.PrimitiveType;
/**
* Json type in Nereids.
*/
@Developing
public class JsonType extends DataType {
public class JsonType extends PrimitiveType {
public static final JsonType INSTANCE = new JsonType();
@ -62,9 +63,4 @@ public class JsonType extends DataType {
public int width() {
return WIDTH;
}
@Override
public String toSql() {
return "JSON";
}
}

View File

@ -71,7 +71,7 @@ public class StructField {
return new StructField(name, dataType, nullable, comment);
}
public StructField withDataTypeAndNulalble(DataType dataType, boolean nullable) {
public StructField withDataTypeAndNullable(DataType dataType, boolean nullable) {
return new StructField(name, dataType, nullable, comment);
}

View File

@ -265,7 +265,6 @@ public class TypeCoercionUtils {
/**
* return true if datatype has character type in it, cannot use instance of CharacterType because of complex type.
*/
@Developing
public static boolean hasCharacterType(DataType dataType) {
if (dataType instanceof ArrayType) {
return hasCharacterType(((ArrayType) dataType).getItemType());
@ -278,6 +277,27 @@ public class TypeCoercionUtils {
return dataType instanceof CharacterType;
}
/**
 * replace every character type found in dataType with string type for correct type coercion.
 * array item types, map key and value types and struct field types are rewritten recursively,
 * any other type is returned as it is.
 */
public static DataType replaceCharacterToString(DataType dataType) {
    if (dataType instanceof ArrayType) {
        DataType itemType = ((ArrayType) dataType).getItemType();
        return ArrayType.of(replaceCharacterToString(itemType));
    }
    if (dataType instanceof MapType) {
        MapType mapType = (MapType) dataType;
        DataType keyType = replaceCharacterToString(mapType.getKeyType());
        DataType valueType = replaceCharacterToString(mapType.getValueType());
        return MapType.of(keyType, valueType);
    }
    if (dataType instanceof StructType) {
        // rebuild the struct field by field, only the data type of each field changes
        ImmutableList.Builder<StructField> builder = ImmutableList.builder();
        for (StructField field : ((StructType) dataType).getFields()) {
            builder.add(field.withDataType(replaceCharacterToString(field.getDataType())));
        }
        List<StructField> replacedFields = builder.build();
        return new StructType(replacedFields);
    }
    if (dataType instanceof CharacterType) {
        return StringType.INSTANCE;
    }
    return dataType;
}
/**
* The type used for arithmetic operations.
*/
@ -825,6 +845,10 @@ public class TypeCoercionUtils {
// same type
if (left.getDataType().equals(right.getDataType())) {
if (!supportCompare(left.getDataType())) {
throw new AnalysisException("data type " + left.getDataType()
+ " could not used in ComparisonPredicate " + comparisonPredicate.toSql());
}
return comparisonPredicate.withChildren(left, right);
}
@ -837,6 +861,10 @@ public class TypeCoercionUtils {
Optional<DataType> commonType = findWiderTypeForTwoForComparison(
left.getDataType(), right.getDataType(), false);
if (commonType.isPresent()) {
if (!supportCompare(commonType.get())) {
throw new AnalysisException("data type " + commonType.get()
+ " could not used in ComparisonPredicate " + comparisonPredicate.toSql());
}
left = castIfNotSameType(left, commonType.get());
right = castIfNotSameType(right, commonType.get());
}
@ -852,6 +880,10 @@ public class TypeCoercionUtils {
if (inPredicate.getOptions().stream().map(Expression::getDataType)
.allMatch(dt -> dt.equals(inPredicate.getCompareExpr().getDataType()))) {
if (!supportCompare(inPredicate.getCompareExpr().getDataType())) {
throw new AnalysisException("data type " + inPredicate.getCompareExpr().getDataType()
+ " could not used in InPredicate " + inPredicate.toSql());
}
return inPredicate;
}
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonTypeForComparison(
@ -859,6 +891,10 @@ public class TypeCoercionUtils {
.stream()
.map(Expression::getDataType).collect(Collectors.toList()),
true);
if (optionalCommonType.isPresent() && !supportCompare(optionalCommonType.get())) {
throw new AnalysisException("data type " + optionalCommonType.get()
+ " could not used in InPredicate " + inPredicate.toSql());
}
return optionalCommonType
.map(commonType -> {
@ -969,6 +1005,11 @@ public class TypeCoercionUtils {
Map<Boolean, List<DataType>> partitioned = dataTypes.stream()
.collect(Collectors.partitioningBy(TypeCoercionUtils::hasCharacterType));
List<DataType> needTypeCoercion = Lists.newArrayList(Sets.newHashSet(partitioned.get(true)));
if (needTypeCoercion.size() > 1 || !partitioned.get(false).isEmpty()) {
needTypeCoercion = needTypeCoercion.stream()
.map(TypeCoercionUtils::replaceCharacterToString)
.collect(Collectors.toList());
}
needTypeCoercion.addAll(partitioned.get(false));
return needTypeCoercion.stream().map(Optional::of).reduce(Optional.of(NullType.INSTANCE),
(r, c) -> {
@ -1175,6 +1216,11 @@ public class TypeCoercionUtils {
Map<Boolean, List<DataType>> partitioned = dataTypes.stream()
.collect(Collectors.partitioningBy(TypeCoercionUtils::hasCharacterType));
List<DataType> needTypeCoercion = Lists.newArrayList(Sets.newHashSet(partitioned.get(true)));
if (needTypeCoercion.size() > 1 || !partitioned.get(false).isEmpty()) {
needTypeCoercion = needTypeCoercion.stream()
.map(TypeCoercionUtils::replaceCharacterToString)
.collect(Collectors.toList());
}
needTypeCoercion.addAll(partitioned.get(false));
return needTypeCoercion.stream().map(Optional::of).reduce(Optional.of(NullType.INSTANCE),
(r, c) -> {
@ -1493,4 +1539,17 @@ public class TypeCoercionUtils {
return binaryArithmetic.withChildren(castIfNotSameType(left, dt1),
castIfNotSameType(right, dt2));
}
/**
 * whether dataType could be used in comparison and in predicates.
 * only primitive types qualify, and among them object types and json type are excluded.
 */
private static boolean supportCompare(DataType dataType) {
    boolean primitive = dataType instanceof PrimitiveType;
    return primitive && !dataType.isObjectType() && !(dataType instanceof JsonType);
}
}