[feat](Nereids): add is null predicate for the first partition when updating mv by partition (#32463)

This commit is contained in:
谢健
2024-03-22 10:25:49 +08:00
committed by yiguolei
parent 3f36aa2d48
commit 8f3f9a53be
2 changed files with 102 additions and 67 deletions

View File

@ -49,12 +49,15 @@ import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.RelationUtil;
import org.apache.doris.qe.ConnectContext;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Range;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@ -111,11 +114,27 @@ public class UpdateMvByPartitionCommand extends InsertOverwriteTableCommand {
return builder.build();
}
private static Set<Expression> constructPredicates(Set<PartitionItem> partitions, String colName) {
/**
* Construct predicates for partition items. The min key is the min key of the range items;
* for list partition or less-than partition items, the min key is null.
*/
@VisibleForTesting
public static Set<Expression> constructPredicates(Set<PartitionItem> partitions, String colName) {
Set<Expression> predicates = new HashSet<>();
UnboundSlot slot = new UnboundSlot(colName);
return partitions.stream()
.map(item -> convertPartitionItemToPredicate(item, slot))
.collect(ImmutableSet.toImmutableSet());
if (partitions.isEmpty()) {
return Sets.newHashSet(BooleanLiteral.TRUE);
}
if (partitions.iterator().next() instanceof ListPartitionItem) {
for (PartitionItem item : partitions) {
predicates.add(convertListPartitionToIn(item, slot));
}
} else {
for (PartitionItem item : partitions) {
predicates.add(convertRangePartitionToCompare(item, slot));
}
}
return predicates;
}
private static Expression convertPartitionKeyToLiteral(PartitionKey key) {
@ -123,42 +142,48 @@ public class UpdateMvByPartitionCommand extends InsertOverwriteTableCommand {
Type.fromPrimitiveType(key.getTypes().get(0)));
}
private static Expression convertPartitionItemToPredicate(PartitionItem item, Slot col) {
if (item instanceof ListPartitionItem) {
List<Expression> inValues = ((ListPartitionItem) item).getItems().stream()
.map(UpdateMvByPartitionCommand::convertPartitionKeyToLiteral)
.collect(ImmutableList.toImmutableList());
List<Expression> predicates = new ArrayList<>();
if (inValues.stream().anyMatch(NullLiteral.class::isInstance)) {
inValues = inValues.stream()
.filter(e -> !(e instanceof NullLiteral))
.collect(Collectors.toList());
Expression isNullPredicate = new IsNull(col);
predicates.add(isNullPredicate);
}
if (!inValues.isEmpty()) {
predicates.add(new InPredicate(col, inValues));
}
if (predicates.isEmpty()) {
return BooleanLiteral.of(true);
}
return ExpressionUtils.or(predicates);
} else {
Range<PartitionKey> range = item.getItems();
List<Expression> exprs = new ArrayList<>();
if (range.hasLowerBound() && !range.lowerEndpoint().isMinValue()) {
PartitionKey key = range.lowerEndpoint();
exprs.add(new GreaterThanEqual(col, convertPartitionKeyToLiteral(key)));
}
if (range.hasUpperBound() && !range.upperEndpoint().isMaxValue()) {
PartitionKey key = range.upperEndpoint();
exprs.add(new LessThan(col, convertPartitionKeyToLiteral(key)));
}
if (exprs.isEmpty()) {
return BooleanLiteral.of(true);
}
return ExpressionUtils.and(exprs);
/**
 * Converts one list partition item into a predicate on {@code col}.
 *
 * <p>Every partition key of the item is turned into a literal. A NULL literal
 * contributes an {@code IS NULL} test, all remaining literals are gathered into
 * a single {@code IN} predicate, and the two parts are combined with OR.
 *
 * @param item the partition item; it is cast to {@link ListPartitionItem}
 * @param col the slot the generated predicate is applied to
 * @return the combined predicate, or the boolean literal TRUE when the item
 *         yields no literal at all
 */
private static Expression convertListPartitionToIn(PartitionItem item, Slot col) {
    List<Expression> literals = ((ListPartitionItem) item).getItems().stream()
            .map(UpdateMvByPartitionCommand::convertPartitionKeyToLiteral)
            .collect(ImmutableList.toImmutableList());

    // Split the literals in one pass: NULL cannot be matched by IN, so it is
    // tracked separately and expressed as IS NULL.
    boolean containsNull = false;
    List<Expression> nonNullValues = new ArrayList<>();
    for (Expression literal : literals) {
        if (literal instanceof NullLiteral) {
            containsNull = true;
        } else {
            nonNullValues.add(literal);
        }
    }

    List<Expression> disjuncts = new ArrayList<>();
    if (containsNull) {
        disjuncts.add(new IsNull(col));
    }
    if (!nonNullValues.isEmpty()) {
        disjuncts.add(new InPredicate(col, nonNullValues));
    }
    return disjuncts.isEmpty() ? BooleanLiteral.of(true) : ExpressionUtils.or(disjuncts);
}
/**
 * Converts one range partition item into a comparison predicate on {@code col}.
 *
 * <p>A lower endpoint that is present and is not the min-value key adds
 * {@code col >= lower}; an upper endpoint that is present and is not the
 * max-value key adds {@code col < upper}. The added comparisons are combined
 * with AND.
 *
 * <p>When the range has no effective lower bound (either no lower endpoint at
 * all, or a lower endpoint that is the min-value key) the result is additionally
 * OR-ed with {@code col IS NULL}.
 *
 * @param item the range partition item; its items are read as a {@code Range<PartitionKey>}
 * @param col the slot the generated predicate is applied to
 * @return the predicate for the range, or the boolean literal TRUE when neither
 *         endpoint contributes a comparison
 */
private static Expression convertRangePartitionToCompare(PartitionItem item, Slot col) {
    Range<PartitionKey> range = item.getItems();
    // A lower endpoint equal to the min-value key is treated exactly like a missing one.
    // The short-circuit keeps lowerEndpoint() from being called on a range without one.
    boolean unboundedBelow = !range.hasLowerBound() || range.lowerEndpoint().isMinValue();

    List<Expression> expressions = new ArrayList<>();
    if (!unboundedBelow) {
        PartitionKey key = range.lowerEndpoint();
        expressions.add(new GreaterThanEqual(col, convertPartitionKeyToLiteral(key)));
    }
    if (range.hasUpperBound() && !range.upperEndpoint().isMaxValue()) {
        PartitionKey key = range.upperEndpoint();
        expressions.add(new LessThan(col, convertPartitionKeyToLiteral(key)));
    }
    if (expressions.isEmpty()) {
        return BooleanLiteral.of(true);
    }
    Expression predicate = ExpressionUtils.and(expressions);
    // A partition without a lower bound can be the first partition of LESS THAN partitions.
    // Null values can be inserted into this partition, so an "OR col IS NULL" condition is
    // needed. This uses the same "unbounded below" test as the lower-bound comparison above,
    // so a range stored as [MIN_VALUE, x) is handled like a range with no lower endpoint.
    if (unboundedBelow) {
        predicate = ExpressionUtils.or(predicate, new IsNull(col));
    }
    return predicate;
}
static class PredicateAdder extends DefaultPlanRewriter<Map<TableIf, Set<Expression>>> {

View File

@ -20,58 +20,68 @@ package org.apache.doris.nereids.trees.plans.commands;
import org.apache.doris.analysis.PartitionValue;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.ListPartitionItem;
import org.apache.doris.catalog.PartitionItem;
import org.apache.doris.catalog.PartitionKey;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.RangePartitionItem;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.types.IntegerType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Range;
import com.google.common.collect.Sets;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Set;
class UpdateMvByPartitionCommandTest {
@Test
void testMaxMin() throws AnalysisException, NoSuchMethodException, InvocationTargetException,
IllegalAccessException {
Method m = UpdateMvByPartitionCommand.class.getDeclaredMethod("convertPartitionItemToPredicate", PartitionItem.class,
Slot.class);
m.setAccessible(true);
void testFirstPartWithoutLowerBound() throws AnalysisException {
Column column = new Column("a", PrimitiveType.INT);
PartitionKey upper = PartitionKey.createPartitionKey(ImmutableList.of(PartitionValue.MAX_VALUE), ImmutableList.of(column));
PartitionKey lower = PartitionKey.createPartitionKey(ImmutableList.of(new PartitionValue(1L)), ImmutableList.of(column));
Range<PartitionKey> range = Range.closedOpen(lower, upper);
RangePartitionItem rangePartitionItem = new RangePartitionItem(range);
Expression expr = (Expression) m.invoke(null, rangePartitionItem, new SlotReference("s", IntegerType.INSTANCE));
Assertions.assertTrue(expr instanceof GreaterThanEqual);
PartitionKey upper = PartitionKey.createPartitionKey(ImmutableList.of(new PartitionValue(1L)),
ImmutableList.of(column));
Range<PartitionKey> range1 = Range.lessThan(upper);
RangePartitionItem item1 = new RangePartitionItem(range1);
Set<Expression> predicates = UpdateMvByPartitionCommand.constructPredicates(Sets.newHashSet(item1), "s");
Assertions.assertEquals("((s < 1) OR s IS NULL)", predicates.iterator().next().toSql());
}
@Test
void testNull() throws AnalysisException, NoSuchMethodException, InvocationTargetException,
IllegalAccessException {
Method m = UpdateMvByPartitionCommand.class.getDeclaredMethod("convertPartitionItemToPredicate", PartitionItem.class,
Slot.class);
m.setAccessible(true);
void testMaxMin() throws AnalysisException {
Column column = new Column("a", PrimitiveType.INT);
PartitionKey v = PartitionKey.createListPartitionKeyWithTypes(ImmutableList.of(new PartitionValue("NULL", true)), ImmutableList.of(column.getType()), false);
PartitionKey upper = PartitionKey.createPartitionKey(ImmutableList.of(PartitionValue.MAX_VALUE),
ImmutableList.of(column));
PartitionKey lower = PartitionKey.createPartitionKey(ImmutableList.of(new PartitionValue(1L)),
ImmutableList.of(column));
Range<PartitionKey> range = Range.closedOpen(lower, upper);
RangePartitionItem rangePartitionItem = new RangePartitionItem(range);
Set<Expression> predicates = UpdateMvByPartitionCommand.constructPredicates(Sets.newHashSet(rangePartitionItem),
"s");
Expression expr = predicates.iterator().next();
System.out.println(expr.toSql());
Assertions.assertEquals("(s >= 1)", expr.toSql());
}
@Test
void testNull() throws AnalysisException {
Column column = new Column("a", PrimitiveType.INT);
PartitionKey v = PartitionKey.createListPartitionKeyWithTypes(
ImmutableList.of(new PartitionValue("NULL", true)), ImmutableList.of(column.getType()), false);
ListPartitionItem listPartitionItem = new ListPartitionItem(ImmutableList.of(v));
Expression expr = (Expression) m.invoke(null, listPartitionItem, new SlotReference("s", IntegerType.INSTANCE));
Expression expr = UpdateMvByPartitionCommand.constructPredicates(Sets.newHashSet(listPartitionItem), "s")
.iterator().next();
Assertions.assertTrue(expr instanceof IsNull);
PartitionKey v1 = PartitionKey.createListPartitionKeyWithTypes(ImmutableList.of(new PartitionValue("NULL", true)), ImmutableList.of(column.getType()), false);
PartitionKey v2 = PartitionKey.createListPartitionKeyWithTypes(ImmutableList.of(new PartitionValue("1", false)), ImmutableList.of(column.getType()), false);
PartitionKey v1 = PartitionKey.createListPartitionKeyWithTypes(
ImmutableList.of(new PartitionValue("NULL", true)), ImmutableList.of(column.getType()), false);
PartitionKey v2 = PartitionKey.createListPartitionKeyWithTypes(ImmutableList.of(new PartitionValue("1", false)),
ImmutableList.of(column.getType()), false);
listPartitionItem = new ListPartitionItem(ImmutableList.of(v1, v2));
expr = (Expression) m.invoke(null, listPartitionItem, new SlotReference("s", IntegerType.INSTANCE));
expr = UpdateMvByPartitionCommand.constructPredicates(Sets.newHashSet(listPartitionItem), "s").iterator()
.next();
Assertions.assertEquals("(s IS NULL OR s IN (1))", expr.toSql());
}
}