[refactor](Nereids) Refactor and optimize partition pruning (#18003)

The legacy PartitionPruner only supports some simple cases; several useful cases are not supported:
1. it cannot evaluate builtin functions, e.g. `cast(part_column as bigint) = 1`
2. it cannot prune multi-level range partitions. For example, the partition `[('1', 'a'), ('2', 'b'))` implies these constraints:
    - first_part_column between '1' and '2'
    - if first_part_column = '1' then second_part_column >= 'a'
    - if first_part_column = '2' then second_part_column < 'b'

This PR refactors it and supports:
1. using a visitor to evaluate functions and fold constants
2. if the partition column is discrete, like int or date, expanding the range and evaluating each value, e.g. `[1, 5)` will be expanded to `[1, 2, 3, 4]`
3. pruning multi-level range partitions, as described above
4. evaluating functions over a range slot, e.g. for the datetime range partition `[('2023-03-21 00:00:00'), ('2023-03-21 23:59:59'))`, if the filter is `date(col1) = '2023-03-22'`, this partition will be pruned, because we know that the date is always `2023-03-21`. You can implement the visit method in FoldConstantRuleOnFE and OneRangePartitionEvaluator to support more functions like this.

### How can we do it so finely ?
Generally, the columns of a range partition can be separated into three kinds: `const`, `range`, `other`.
For example, given the partition `[('1', 'a', 'D'), ('1', 'c', 'D'))`:
1. the first partition column is `const`: it always equals '1'
2. the second partition column is `range`: `slot >= 'a' and slot <= 'c'`. If there were no later slot, it would be `slot >= 'a' and slot < 'c'`
3. the third partition column is `other`: regardless of whether the upper and lower bounds are the same, it can take multiple values, e.g. `('1', 'a', 'D')`, `('1', 'a', 'F')`, `('1', 'b', 'A')`, `('1', 'c', 'A')`

In a partition, there is exactly one `range` slot, and zero or more `const`/`other` slots.
Normally, a partition looks like [const*, range, other*]; these are the possible shapes:
1. [range], e.g `[('1'), ('10'))`
2. [const, range], e.g. `[('1', 'a'), ('1', 'd'))`
3. [range, other, other], e.g. `[('1', '1', '1'), ('2', '1', '1'))`
4. [const, const, ..., range, other, other, ...], e.g. `[('1', '1', '2', '3', '4'), ('1', '1', '3', '3', '4'))`

The properties of `const`: 
1. we can replace the slot with a literal when evaluating the expression tree.

The properties of `range`:
1. if the slot's data type is discrete, like int or date, we can expand it to literals and evaluate the expression tree
2. if the type is not discrete, like datetime, or there are too many discrete values, like [1, 1000000), we can keep the slot in the expression tree and assign a range to it. When evaluating the expression tree, we also compute the range and check whether it is the empty set; if so, we can simplify to BooleanLiteral.FALSE and skip this partition.
3. if the range slot satisfies some conditions, we can also fold functions applied to the slot; see the datetime example above

The properties of `other`:
1. only when the previous slot is a literal equal to the lower bound or upper bound of the partition can we shrink the range of the `other` slot

According to these properties, we can prune partitions this finely.


At runtime, the `range` and `other` slots may have their range of values shrunk,
e.g.
1. for the partition `[('a'), ('b'))`, the predicate `part_col = 'a'` shrinks the range `['a', 'b')` to `['a']`, as if the `range` slot were downgraded to a `const` slot;
2. for the partition `[('a', '1'), ('b', '10'))`, the predicate `part_col1 = 'a'` shrinks the range of the `other` slot from unknown (all values) to `['1', +∞)`, as if the `other` slot were downgraded to a `range` slot.

But to keep things simple, I don't change the slot type at runtime; I just shrink the ColumnRange.
This commit is contained in:
924060929
2023-03-24 09:06:52 +08:00
committed by GitHub
parent d3e7f12ada
commit 321bb3e9ee
25 changed files with 2320 additions and 315 deletions

View File

@ -161,6 +161,10 @@ public class PartitionKey implements Comparable<PartitionKey>, Writable {
return keys;
}
/** Returns the {@code types} list held by this partition key. */
public List<PrimitiveType> getTypes() {
    return this.types;
}
public long getHashValue() {
CRC32 hashValue = new CRC32();
int i = 0;

View File

@ -50,6 +50,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import com.google.common.collect.ImmutableList;
@ -62,6 +63,7 @@ import java.util.Optional;
import java.util.Set;
import java.util.Stack;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import javax.annotation.Nullable;
/**
@ -280,6 +282,23 @@ public class CascadesContext implements ScheduleContext, PlanSource {
this.outerScope = Optional.ofNullable(outerScope);
}
/**
 * Reads a value derived from the session variables through the statement context cache.
 *
 * @param cacheName key passed to {@code StatementContext.getOrRegisterCache}
 * @param defaultValue value returned when either the connect context or the statement context is null
 * @param variableSupplier extracts the wanted value from the current {@link SessionVariable}
 * @return the result of {@code getOrRegisterCache}, or {@code defaultValue} when a context is missing
 */
public <T> T getAndCacheSessionVariable(String cacheName,
        T defaultValue, Function<SessionVariable, T> variableSupplier) {
    ConnectContext connectContext = getConnectContext();
    // the statement context is only looked up once a connect context is known to exist
    StatementContext statementContext = connectContext == null ? null : getStatementContext();
    if (statementContext == null) {
        return defaultValue;
    }
    return statementContext.getOrRegisterCache(cacheName,
            () -> variableSupplier.apply(connectContext.getSessionVariable()));
}
private CascadesContext execute(Job job) {
pushJob(job);
jobScheduler.executeJobPool(this);

View File

@ -17,8 +17,6 @@
package org.apache.doris.nereids.jobs;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.memo.CopyInResult;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
@ -34,7 +32,6 @@ import org.apache.doris.nereids.metrics.event.TransformEvent;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleSet;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import com.google.common.base.Preconditions;
@ -45,7 +42,6 @@ import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
/**
* Abstract class for all job using for analyze and optimize query plan in Nereids.
@ -149,29 +145,12 @@ public abstract class Job implements TracerSupplier {
}
public static Set<String> getDisableRules(JobContext context) {
return getAndCacheSessionVariable(context, "disableNereidsRules",
ImmutableSet.of(), SessionVariable::getDisableNereidsRules);
return context.getCascadesContext().getAndCacheSessionVariable(
"disableNereidsRules", ImmutableSet.of(), SessionVariable::getDisableNereidsRules);
}
public static boolean isTraceEnable(JobContext context) {
return getAndCacheSessionVariable(context, "isTraceEnable",
false, SessionVariable::isEnableNereidsTrace);
}
private static <T> T getAndCacheSessionVariable(JobContext context, String cacheName,
T defaultValue, Function<SessionVariable, T> variableSupplier) {
CascadesContext cascadesContext = context.getCascadesContext();
ConnectContext connectContext = cascadesContext.getConnectContext();
if (connectContext == null) {
return defaultValue;
}
StatementContext statementContext = cascadesContext.getStatementContext();
if (statementContext == null) {
return defaultValue;
}
T cacheResult = statementContext.getOrRegisterCache(cacheName,
() -> variableSupplier.apply(connectContext.getSessionVariable()));
return cacheResult;
return context.getCascadesContext().getAndCacheSessionVariable(
"isTraceEnable", false, SessionVariable::isEnableNereidsTrace);
}
}

View File

@ -1719,7 +1719,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
private String parseTVFPropertyItem(TvfPropertyItemContext item) {
if (item.constant() != null) {
Object constant = visit(item.constant());
if (constant instanceof Literal && ((Literal) constant).isStringLiteral()) {
if (constant instanceof Literal && ((Literal) constant).isStringLikeLiteral()) {
return ((Literal) constant).getStringValue();
}
}

View File

@ -0,0 +1,113 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rewrite.rules;
import org.apache.doris.catalog.PartitionKey;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import com.google.common.base.MoreObjects;
import com.google.common.base.Objects;
import com.google.common.collect.BoundType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Range;
/**
 * An endpoint of a column value range: a thin {@link Comparable} wrapper around one {@link Literal},
 * used as the element type of Guava {@link Range}.
 *
 * <p>Ordering is delegated to the legacy literal form ({@code Literal.toLegacyLiteral()}), while
 * {@link #equals(Object)} and {@link #hashCode()} are based on the wrapped literal itself.
 */
public class ColumnBound implements Comparable<ColumnBound> {
    private final Literal value;

    private ColumnBound(Literal value) {
        this.value = value;
    }

    /** Wraps the given literal in a bound. */
    public static ColumnBound of(Literal expr) {
        return new ColumnBound(expr);
    }

    public Literal getValue() {
        return value;
    }

    @Override
    public int compareTo(ColumnBound o) {
        // compare through the legacy literal representation of both sides
        return value.toLegacyLiteral().compareTo(o.value.toLegacyLiteral());
    }

    /** Range of everything strictly below {@code value} ({@code <}). */
    public static Range<ColumnBound> lessThen(Literal value) {
        return Range.lessThan(of(value));
    }

    /** Range of everything up to and including {@code value} ({@code <=}). */
    public static Range<ColumnBound> atMost(Literal value) {
        return Range.atMost(of(value));
    }

    /** Range of everything strictly above {@code value} ({@code >}). */
    public static Range<ColumnBound> greaterThan(Literal value) {
        return Range.greaterThan(of(value));
    }

    /** Range of everything from {@code value} upwards, inclusive ({@code >=}). */
    public static Range<ColumnBound> atLeast(Literal value) {
        return Range.atLeast(of(value));
    }

    /** The unbounded range. */
    public static Range<ColumnBound> all() {
        return Range.all();
    }

    /** Shortcut for {@code ColumnRange.empty()}. */
    public static ColumnRange empty() {
        return ColumnRange.empty();
    }

    /** Range containing exactly {@code value}. */
    public static Range<ColumnBound> singleton(Literal value) {
        return Range.singleton(of(value));
    }

    /** Range from {@code lower} to {@code upper}, both ends closed. */
    public static Range<ColumnBound> between(Literal lower, Literal upper) {
        return range(lower, BoundType.CLOSED, upper, BoundType.CLOSED);
    }

    /** Range from {@code lower} to {@code upper} with the given bound type on each end. */
    public static Range<ColumnBound> range(Literal lower, BoundType lowerType, Literal upper, BoundType upperType) {
        return Range.range(of(lower), lowerType, of(upper), upperType);
    }

    @Override
    public String toString() {
        // render the literal the same way a partition key renders its keys
        String rendered = PartitionKey.toString(ImmutableList.of(value.toLegacyLiteral()));
        return MoreObjects.toStringHelper(this)
                .add("value", rendered)
                .toString();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        return o != null
                && getClass() == o.getClass()
                && Objects.equal(value, ((ColumnBound) o).value);
    }

    @Override
    public int hashCode() {
        return Objects.hashCode(value);
    }
}

View File

@ -0,0 +1,141 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rewrite.rules;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import com.google.common.collect.BoundType;
import com.google.common.collect.ImmutableRangeSet;
import com.google.common.collect.Range;
import com.google.common.collect.RangeSet;
import com.google.common.collect.TreeRangeSet;
import java.util.Objects;
import java.util.Set;
/**
 * The set of values a column may take, stored as a Guava {@link RangeSet} of {@link ColumnBound}.
 *
 * <p>{@link #intersect}, {@link #union} and {@link #complete} return a new {@code ColumnRange}
 * and leave the receiver untouched.
 */
public class ColumnRange {
    private final RangeSet<ColumnBound> rangeSet;

    /** Creates a range that contains no values. */
    public ColumnRange() {
        rangeSet = ImmutableRangeSet.of();
    }

    /** Creates a range set made of the single given range. */
    public ColumnRange(Range<ColumnBound> range) {
        this.rangeSet = ImmutableRangeSet.of(range);
    }

    /** Wraps the given range set as-is (no copy is made); it must not be null. */
    public ColumnRange(RangeSet<ColumnBound> rangeSet) {
        this.rangeSet = Objects.requireNonNull(rangeSet);
    }

    // <
    public static ColumnRange lessThen(Literal value) {
        return new ColumnRange(ColumnBound.lessThen(value));
    }

    // <=
    public static ColumnRange atMost(Literal value) {
        return new ColumnRange(ColumnBound.atMost(value));
    }

    // >
    public static ColumnRange greaterThan(Literal value) {
        return new ColumnRange(ColumnBound.greaterThan(value));
    }

    // >=
    public static ColumnRange atLeast(Literal value) {
        return new ColumnRange(ColumnBound.atLeast(value));
    }

    public static ColumnRange all() {
        return new ColumnRange(ColumnBound.all());
    }

    public static ColumnRange empty() {
        return new ColumnRange();
    }

    public static ColumnRange singleton(Literal value) {
        return new ColumnRange(ColumnBound.singleton(value));
    }

    public static ColumnRange between(Literal lower, Literal upper) {
        return new ColumnRange(ColumnBound.between(lower, upper));
    }

    public static ColumnRange range(Literal lower, BoundType lowerType, Literal upper, BoundType upperType) {
        return new ColumnRange(ColumnBound.range(lower, lowerType, upper, upperType));
    }

    /** Returns the values contained in both this range and {@code range}. */
    public ColumnRange intersect(ColumnRange range) {
        RangeSet<ColumnBound> intersection = TreeRangeSet.create();
        // clip this range set by every range of the other side and collect the pieces
        for (Range<ColumnBound> otherRange : range.rangeSet.asRanges()) {
            intersection.addAll(rangeSet.subRangeSet(otherRange));
        }
        return new ColumnRange(intersection);
    }

    /** Returns the values contained in this range or in {@code range}. */
    public ColumnRange union(ColumnRange range) {
        RangeSet<ColumnBound> merged = TreeRangeSet.create();
        merged.addAll(rangeSet);
        merged.addAll(range.rangeSet);
        return new ColumnRange(merged);
    }

    public Set<Range<ColumnBound>> asRanges() {
        return rangeSet.asRanges();
    }

    /** Returns the complement of this range set. */
    public ColumnRange complete() {
        return new ColumnRange(rangeSet.complement());
    }

    public boolean isEmptyRange() {
        return rangeSet.isEmpty();
    }

    /** Returns true when the set is exactly one range whose two endpoints exist and are equal. */
    public boolean isSingleton() {
        Set<Range<ColumnBound>> ranges = rangeSet.asRanges();
        if (ranges.size() != 1) {
            return false;
        }
        Range<ColumnBound> onlyRange = ranges.iterator().next();
        return onlyRange.hasLowerBound()
                && onlyRange.hasUpperBound()
                && onlyRange.lowerEndpoint().equals(onlyRange.upperEndpoint());
    }

    /** Returns the span of the underlying range set. */
    public Range<ColumnBound> span() {
        return rangeSet.span();
    }

    /** Returns the lower endpoint of the span of the underlying range set. */
    public ColumnBound getLowerBound() {
        Range<ColumnBound> enclosing = rangeSet.span();
        return enclosing.lowerEndpoint();
    }

    /** Returns the upper endpoint of the span of the underlying range set. */
    public ColumnBound getUpperBound() {
        Range<ColumnBound> enclosing = rangeSet.span();
        return enclosing.upperEndpoint();
    }

    @Override
    public String toString() {
        return rangeSet.toString();
    }
}

View File

@ -49,16 +49,22 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.Array;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ConnectionId;
import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentUser;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Database;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Date;
import org.apache.doris.nereids.trees.expressions.functions.scalar.User;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Version;
import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.qe.GlobalVariable;
@ -371,6 +377,26 @@ public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule {
return new ArrayLiteral(arguments);
}
/**
 * Folds {@code date(<literal>)}.
 *
 * <p>The call is returned unchanged unless {@code allArgsIsAllLiteral(date)} holds. A
 * {@link NullLiteral} argument folds to a null literal of the function's data type; a datetime
 * or datetime-v2 literal folds to a date / date-v2 literal built from its year, month and day.
 * Any other argument type is left unfolded.
 */
@Override
public Expression visitDate(Date date, ExpressionRewriteContext context) {
    if (!allArgsIsAllLiteral(date)) {
        return date;
    }
    Literal argument = (Literal) date.child();
    if (argument instanceof NullLiteral) {
        return new NullLiteral(date.getDataType());
    }
    DataType argumentType = argument.getDataType();
    if (argumentType.isDateTimeType()) {
        DateTimeLiteral dateTime = (DateTimeLiteral) argument;
        return new DateLiteral(dateTime.getYear(), dateTime.getMonth(), dateTime.getDay());
    }
    if (argumentType.isDateTimeV2Type()) {
        DateTimeV2Literal dateTimeV2 = (DateTimeV2Literal) argument;
        return new DateV2Literal(dateTimeV2.getYear(), dateTimeV2.getMonth(), dateTimeV2.getDay());
    }
    // not a datetime literal: nothing to fold here
    return date;
}
@Override
public Expression visitVersion(Version version, ExpressionRewriteContext context) {
return new StringLiteral(GlobalVariable.version);

View File

@ -0,0 +1,101 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rewrite.rules;
import org.apache.doris.catalog.ListPartitionItem;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.IntStream;
/**
 * Evaluator for a single list partition.
 *
 * <p>Every key of the partition item yields one input mapping in which each partition slot is
 * bound to the literal at the same position. Evaluation substitutes those literals for the slots
 * and forwards any non-literal result to {@link FoldConstantRuleOnFE}.
 */
public class OneListPartitionEvaluator
        extends DefaultExpressionRewriter<Map<Slot, PartitionSlotInput>> implements OnePartitionEvaluator {
    private final long partitionId;
    private final List<Slot> partitionSlots;
    private final ListPartitionItem partitionItem;
    private final ExpressionRewriteContext expressionRewriteContext;

    public OneListPartitionEvaluator(long partitionId, List<Slot> partitionSlots,
            ListPartitionItem partitionItem, CascadesContext cascadesContext) {
        this.partitionId = partitionId;
        this.partitionSlots = Objects.requireNonNull(partitionSlots, "partitionSlots cannot be null");
        this.partitionItem = Objects.requireNonNull(partitionItem, "partitionItem cannot be null");
        this.expressionRewriteContext = new ExpressionRewriteContext(
                Objects.requireNonNull(cascadesContext, "cascadesContext cannot be null"));
    }

    @Override
    public long getPartitionId() {
        return partitionId;
    }

    /** Builds one slot-to-input mapping per key of the partition item. */
    @Override
    public List<Map<Slot, PartitionSlotInput>> getOnePartitionInputs() {
        return partitionItem.getItems().stream()
                .map(keys -> {
                    List<Literal> literals = keys.getKeys()
                            .stream()
                            .map(legacy -> Literal.fromLegacyLiteral(legacy, legacy.getType()))
                            .collect(ImmutableList.toImmutableList());
                    return bindSlotsToLiterals(literals);
                }).collect(ImmutableList.toImmutableList());
    }

    /** Pairs the partition slot at each position with the literal at the same position. */
    private Map<Slot, PartitionSlotInput> bindSlotsToLiterals(List<Literal> literals) {
        ImmutableMap.Builder<Slot, PartitionSlotInput> slotToInput = ImmutableMap.builder();
        for (int index = 0; index < partitionSlots.size(); index++) {
            Slot partitionSlot = partitionSlots.get(index);
            // the partition slot will be replaced by this literal
            Literal literal = literals.get(index);
            // a list partition does not need to compute the slot's range,
            // so an empty range map is passed through
            slotToInput.put(partitionSlot, new PartitionSlotInput(literal, ImmutableMap.of()));
        }
        return slotToInput.build();
    }

    @Override
    public Expression visit(Expression expr, Map<Slot, PartitionSlotInput> context) {
        Expression rewritten = super.visit(expr, context);
        if (rewritten instanceof Literal) {
            return rewritten;
        }
        // not a literal yet: just forward to the fold constant rule
        return rewritten.accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext);
    }

    @Override
    public Expression visitSlot(Slot slot, Map<Slot, PartitionSlotInput> context) {
        // a slot with a registered input is replaced by that input's result; others stay as they are
        PartitionSlotInput input = context.get(slot);
        if (input == null) {
            return slot;
        }
        return input.result;
    }

    @Override
    public Expression evaluate(Expression expression, Map<Slot, PartitionSlotInput> currentInputs) {
        return expression.accept(this, currentInputs);
    }
}

View File

@ -0,0 +1,48 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rewrite.rules;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import java.util.List;
import java.util.Map;
/**
* Evaluator bound to exactly one partition: it supplies the candidate inputs of that partition
* and evaluates an expression against one of those inputs.
*/
public interface OnePartitionEvaluator {
/** Returns the id of the partition this evaluator represents. */
long getPartitionId();
/**
* Returns the slot-to-input mappings used to replace partition slots in the predicate.
* For example, a list partition holding the values ('1', 'a') and ('10', 'd') on the two columns
* part_col1 and part_col2 returns the mappings
* [{part_col1: '1', part_col2: 'a'}, {part_col1: '10', part_col2: 'd'}].
* If replacing the slots with any one mapping and evaluating in the PartitionPredicateEvaluator
* yields an expression that is not equal to BooleanLiteral.FALSE, the partition will be scanned
* and the remaining mappings are not evaluated.
*/
List<Map<Slot, PartitionSlotInput>> getOnePartitionInputs();
/**
* Evaluates the expression against one input mapping and returns the resulting expression.
* For example, for the range partition [('1', 'a'), ('10', 'd')) with the two columns part_col1 and
* part_col2: if a child contains `part_col1 = '1'`, the constraint `part_col2 >= 'a'` is recorded;
* furthermore, if `part_col2 < 'a'` is also present, the result expression is BooleanLiteral.FALSE.
*/
Expression evaluate(Expression expression, Map<Slot, PartitionSlotInput> currentInputs);
}

View File

@ -0,0 +1,668 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rewrite.rules;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.catalog.PartitionKey;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.RangePartitionItem;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rewrite.rules.OneRangePartitionEvaluator.EvaluateRangeInput;
import org.apache.doris.nereids.rules.expression.rewrite.rules.OneRangePartitionEvaluator.EvaluateRangeResult;
import org.apache.doris.nereids.rules.expression.rewrite.rules.PartitionRangeExpander.PartitionSlotType;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Date;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.BoundType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Range;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.stream.IntStream;
/**
* OneRangePartitionEvaluator.
*
* you can see the process steps in the comment of PartitionSlotInput.columnRanges
*/
public class OneRangePartitionEvaluator
extends ExpressionVisitor<EvaluateRangeResult, EvaluateRangeInput>
implements OnePartitionEvaluator {
private final long partitionId;
private List<Slot> partitionSlots;
private RangePartitionItem partitionItem;
private ExpressionRewriteContext expressionRewriteContext;
private List<PartitionSlotType> partitionSlotTypes;
private List<Literal> lowers;
private List<Literal> uppers;
private List<List<Expression>> inputs;
private Map<Slot, Boolean> partitionSlotContainsNull;
private Map<Slot, PartitionSlotType> slotToType;
/** OneRangePartitionEvaluator */
public OneRangePartitionEvaluator(long partitionId, List<Slot> partitionSlots,
RangePartitionItem partitionItem, CascadesContext cascadesContext) {
this.partitionId = partitionId;
this.partitionSlots = Objects.requireNonNull(partitionSlots, "partitionSlots cannot be null");
this.partitionItem = Objects.requireNonNull(partitionItem, "partitionItem cannot be null");
this.expressionRewriteContext = new ExpressionRewriteContext(
Objects.requireNonNull(cascadesContext, "cascadesContext cannot be null"));
Range<PartitionKey> range = partitionItem.getItems();
this.lowers = toNereidsLiterals(range.lowerEndpoint());
this.uppers = toNereidsLiterals(range.upperEndpoint());
PartitionRangeExpander expander = new PartitionRangeExpander();
this.partitionSlotTypes = expander.computePartitionSlotTypes(lowers, uppers);
this.slotToType = IntStream.range(0, partitionSlots.size())
.mapToObj(index -> Pair.of(partitionSlots.get(index), partitionSlotTypes.get(index)))
.collect(ImmutableMap.toImmutableMap(Pair::key, Pair::value));
this.partitionSlotContainsNull = IntStream.range(0, partitionSlots.size())
.mapToObj(index -> {
Slot slot = partitionSlots.get(index);
if (!slot.nullable()) {
return Pair.of(slot, false);
}
PartitionSlotType partitionSlotType = partitionSlotTypes.get(index);
boolean maybeNull = false;
switch (partitionSlotType) {
case CONST:
case RANGE:
maybeNull = range.lowerEndpoint().getKeys().get(index).isMinValue();
break;
case OTHER:
maybeNull = true;
break;
default:
throw new AnalysisException("Unknown partition slot type: " + partitionSlotType);
}
return Pair.of(slot, maybeNull);
}).collect(ImmutableMap.toImmutableMap(Pair::key, Pair::value));
int expandThreshold = cascadesContext.getAndCacheSessionVariable(
"partitionPruningExpandThreshold",
10, sessionVariable -> sessionVariable.partitionPruningExpandThreshold);
List<List<Expression>> expandInputs = expander.tryExpandRange(
partitionSlots, lowers, uppers, partitionSlotTypes, expandThreshold);
// after expand range, we will get 2 dimension list like list:
// part_col1: [1], part_col2:[4, 5, 6], we should combine it to
// [1, 4], [1, 5], [1, 6] as inputs
this.inputs = Utils.allCombinations(expandInputs);
}
@Override
public long getPartitionId() {
return partitionId;
}
/**
 * Builds one slot-to-input map for every combination stored in {@code inputs}.
 *
 * For each partition slot the input holds the expression that replaces the slot (a literal when the
 * slot is const or was expanded, otherwise the slot itself) together with the column range known for it.
 * Every returned input is passed through {@code fillSlotRangesToInputs}, so it carries the ranges of
 * all partition slots.
 *
 * @return the candidate inputs of this partition, one map per combination
 * @throws AnalysisException if a non-literal input belongs to a slot that is neither RANGE nor OTHER
 */
@Override
public List<Map<Slot, PartitionSlotInput>> getOnePartitionInputs() {
    List<Map<Slot, PartitionSlotInput>> onePartitionInputs = Lists.newArrayList();
    for (List<Expression> candidate : inputs) {
        // whether every previous column is a literal equal to the lower / upper partition bound
        boolean prefixOnLowerBound = true;
        boolean prefixOnUpperBound = true;
        ImmutableMap.Builder<Slot, PartitionSlotInput> slotToInput = ImmutableMap.builder();
        for (int index = 0; index < partitionSlots.size(); ++index) {
            Slot partitionSlot = partitionSlots.get(index);
            // partitionSlot will be replaced by this expression
            Expression replacement = candidate.get(index);
            PartitionSlotType partitionSlotType = partitionSlotTypes.get(index);
            boolean isLastPartitionColumn = index + 1 == partitionSlots.size();
            ColumnRange slotRange;
            if (replacement instanceof Literal) {
                // const or expanded range
                slotRange = ColumnRange.singleton((Literal) replacement);
                prefixOnLowerBound &= replacement.equals(lowers.get(index));
                prefixOnUpperBound &= replacement.equals(uppers.get(index));
            } else {
                // un expanded range
                switch (partitionSlotType) {
                    case RANGE:
                        // only the last partition column has an exclusive upper bound
                        slotRange = ColumnRange.range(
                                lowers.get(index), BoundType.CLOSED, uppers.get(index),
                                isLastPartitionColumn ? BoundType.OPEN : BoundType.CLOSED);
                        break;
                    case OTHER:
                        if (prefixOnLowerBound) {
                            slotRange = ColumnRange.atLeast(lowers.get(index));
                        } else if (prefixOnUpperBound) {
                            // the upper bound value itself is still reachable unless this is the last
                            // column: later columns decide whether the row is below the upper bound
                            slotRange = isLastPartitionColumn
                                    ? ColumnRange.lessThen(uppers.get(index))
                                    : ColumnRange.atMost(uppers.get(index));
                        } else {
                            // unknown range
                            slotRange = ColumnRange.all();
                        }
                        break;
                    default:
                        throw new AnalysisException("Unknown partition slot type: " + partitionSlotType);
                }
                // a non-literal column breaks the "prefix equals the bound" chain
                prefixOnLowerBound = false;
                prefixOnUpperBound = false;
            }
            ImmutableMap<Slot, ColumnRange> slotToRange = ImmutableMap.of(partitionSlot, slotRange);
            slotToInput.put(partitionSlot, new PartitionSlotInput(replacement, slotToRange));
        }
        onePartitionInputs.add(fillSlotRangesToInputs(slotToInput.build()));
    }
    return onePartitionInputs;
}
/**
 * Evaluates {@code expression} for one candidate input of this partition.
 *
 * The column ranges of the first input are used as the default ranges of the evaluation.
 *
 * @param expression the predicate to evaluate
 * @param currentInputs the replacement and ranges of every partition slot; must not be empty
 * @return the rewritten expression produced by visiting {@code expression}
 */
@Override
public Expression evaluate(Expression expression, Map<Slot, PartitionSlotInput> currentInputs) {
    PartitionSlotInput firstInput = currentInputs.values().iterator().next();
    EvaluateRangeInput rangeInput = new EvaluateRangeInput(firstInput.columnRanges, currentInputs);
    EvaluateRangeResult evaluated = expression.accept(this, rangeInput);
    return evaluated.result;
}
/**
 * Default visitor: evaluates the children, then the expression itself.
 *
 * If the evaluated expression is a non-literal boolean and any child result contains an empty
 * column range, the result is replaced by {@code BooleanLiteral.FALSE}.
 */
@Override
public EvaluateRangeResult visit(Expression expr, EvaluateRangeInput context) {
    EvaluateRangeResult evaluated = evaluateChildrenThenThis(expr, context);
    Expression rewritten = evaluated.result;
    boolean nonLiteralPredicate = rewritten.getDataType() instanceof BooleanType
            && !(rewritten instanceof Literal);
    if (!nonLiteralPredicate) {
        return evaluated;
    }
    // NOTE: this inspects the ranges of the children,
    // !!! which is different from `returnFalseIfExistEmptyRange` !!!
    for (EvaluateRangeResult childResult : evaluated.childrenResult) {
        for (ColumnRange childRange : childResult.columnRanges.values()) {
            if (childRange.isEmptyRange()) {
                return new EvaluateRangeResult(
                        BooleanLiteral.FALSE, evaluated.columnRanges, evaluated.childrenResult);
            }
        }
    }
    return evaluated;
}
/**
 * Replaces a slot by its input when one is registered in the context.
 *
 * Slots without an input are kept as they are and get the default column ranges.
 */
@Override
public EvaluateRangeResult visitSlot(Slot slot, EvaluateRangeInput context) {
    PartitionSlotInput slotInput = context.slotToInput.get(slot);
    if (slotInput != null) {
        // partition slot: use its replacement expression and its ranges
        return new EvaluateRangeResult(slotInput.result, slotInput.columnRanges, ImmutableList.of());
    }
    return new EvaluateRangeResult(slot, context.defaultColumnRanges, ImmutableList.of());
}
/**
 * Evaluates a {@code >} comparison.
 *
 * When the evaluated comparison is between a partition slot and a literal, the range of the slot
 * is intersected with the values satisfying the comparison (see {@code intersectSlotRange}).
 */
@Override
public EvaluateRangeResult visitGreaterThan(GreaterThan greaterThan, EvaluateRangeInput context) {
    EvaluateRangeResult evaluated = evaluateChildrenThenThis(greaterThan, context);
    if (!(evaluated.result instanceof GreaterThan)) {
        return evaluated;
    }
    GreaterThan comparison = (GreaterThan) evaluated.result;
    Expression lhs = comparison.left();
    Expression rhs = comparison.right();
    if (lhs instanceof Slot && rhs instanceof Literal) {
        // slot > literal
        if (!isPartitionSlot((Slot) lhs)) {
            return evaluated;
        }
        Map<Slot, ColumnRange> slotSideRanges = evaluated.childrenResult.get(0).columnRanges;
        return intersectSlotRange(
                evaluated, slotSideRanges, (Slot) lhs, ColumnRange.greaterThan((Literal) rhs));
    }
    if (lhs instanceof Literal && rhs instanceof Slot && isPartitionSlot((Slot) rhs)) {
        // literal > slot, i.e. slot < literal
        Map<Slot, ColumnRange> slotSideRanges = evaluated.childrenResult.get(1).columnRanges;
        return intersectSlotRange(
                evaluated, slotSideRanges, (Slot) rhs, ColumnRange.lessThen((Literal) lhs));
    }
    return evaluated;
}
/**
 * Evaluates a {@code >=} comparison.
 *
 * When the evaluated comparison is between a partition slot and a literal, the range of the slot
 * is intersected with the values satisfying the comparison (see {@code intersectSlotRange}).
 */
@Override
public EvaluateRangeResult visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, EvaluateRangeInput context) {
    EvaluateRangeResult evaluated = evaluateChildrenThenThis(greaterThanEqual, context);
    if (!(evaluated.result instanceof GreaterThanEqual)) {
        return evaluated;
    }
    GreaterThanEqual comparison = (GreaterThanEqual) evaluated.result;
    Expression lhs = comparison.left();
    Expression rhs = comparison.right();
    if (lhs instanceof Slot && rhs instanceof Literal) {
        // slot >= literal
        if (!isPartitionSlot((Slot) lhs)) {
            return evaluated;
        }
        Map<Slot, ColumnRange> slotSideRanges = evaluated.childrenResult.get(0).columnRanges;
        return intersectSlotRange(
                evaluated, slotSideRanges, (Slot) lhs, ColumnRange.atLeast((Literal) rhs));
    }
    if (lhs instanceof Literal && rhs instanceof Slot && isPartitionSlot((Slot) rhs)) {
        // literal >= slot, i.e. slot <= literal
        Map<Slot, ColumnRange> slotSideRanges = evaluated.childrenResult.get(1).columnRanges;
        return intersectSlotRange(
                evaluated, slotSideRanges, (Slot) rhs, ColumnRange.atMost((Literal) lhs));
    }
    return evaluated;
}
/**
 * Evaluates a {@code <} comparison.
 *
 * When the evaluated comparison is between a partition slot and a literal, the range of the slot
 * is intersected with the values satisfying the comparison (see {@code intersectSlotRange}).
 */
@Override
public EvaluateRangeResult visitLessThan(LessThan lessThan, EvaluateRangeInput context) {
    EvaluateRangeResult evaluated = evaluateChildrenThenThis(lessThan, context);
    if (!(evaluated.result instanceof LessThan)) {
        return evaluated;
    }
    LessThan comparison = (LessThan) evaluated.result;
    Expression lhs = comparison.left();
    Expression rhs = comparison.right();
    if (lhs instanceof Slot && rhs instanceof Literal) {
        // slot < literal
        if (!isPartitionSlot((Slot) lhs)) {
            return evaluated;
        }
        Map<Slot, ColumnRange> slotSideRanges = evaluated.childrenResult.get(0).columnRanges;
        return intersectSlotRange(
                evaluated, slotSideRanges, (Slot) lhs, ColumnRange.lessThen((Literal) rhs));
    }
    if (lhs instanceof Literal && rhs instanceof Slot && isPartitionSlot((Slot) rhs)) {
        // literal < slot, i.e. slot > literal
        Map<Slot, ColumnRange> slotSideRanges = evaluated.childrenResult.get(1).columnRanges;
        return intersectSlotRange(
                evaluated, slotSideRanges, (Slot) rhs, ColumnRange.greaterThan((Literal) lhs));
    }
    return evaluated;
}
/**
 * Evaluates a {@code <=} comparison.
 *
 * When the evaluated comparison is between a partition slot and a literal, the range of the slot
 * is intersected with the values satisfying the comparison (see {@code intersectSlotRange}).
 */
@Override
public EvaluateRangeResult visitLessThanEqual(LessThanEqual lessThanEqual, EvaluateRangeInput context) {
    EvaluateRangeResult evaluated = evaluateChildrenThenThis(lessThanEqual, context);
    if (!(evaluated.result instanceof LessThanEqual)) {
        return evaluated;
    }
    LessThanEqual comparison = (LessThanEqual) evaluated.result;
    Expression lhs = comparison.left();
    Expression rhs = comparison.right();
    if (lhs instanceof Slot && rhs instanceof Literal) {
        // slot <= literal
        if (!isPartitionSlot((Slot) lhs)) {
            return evaluated;
        }
        Map<Slot, ColumnRange> slotSideRanges = evaluated.childrenResult.get(0).columnRanges;
        return intersectSlotRange(
                evaluated, slotSideRanges, (Slot) lhs, ColumnRange.atMost((Literal) rhs));
    }
    if (lhs instanceof Literal && rhs instanceof Slot && isPartitionSlot((Slot) rhs)) {
        // literal <= slot, i.e. slot >= literal
        Map<Slot, ColumnRange> slotSideRanges = evaluated.childrenResult.get(1).columnRanges;
        return intersectSlotRange(
                evaluated, slotSideRanges, (Slot) rhs, ColumnRange.atLeast((Literal) lhs));
    }
    return evaluated;
}
/**
 * Evaluates an {@code =} comparison.
 *
 * When the evaluated comparison is between a partition slot and a literal, the range of the slot
 * is intersected with the single literal value (see {@code intersectSlotRange}).
 */
@Override
public EvaluateRangeResult visitEqualTo(EqualTo equalTo, EvaluateRangeInput context) {
    EvaluateRangeResult evaluated = evaluateChildrenThenThis(equalTo, context);
    if (!(evaluated.result instanceof EqualTo)) {
        return evaluated;
    }
    EqualTo comparison = (EqualTo) evaluated.result;
    Expression lhs = comparison.left();
    Expression rhs = comparison.right();
    if (lhs instanceof Slot && rhs instanceof Literal) {
        // slot = literal
        if (!isPartitionSlot((Slot) lhs)) {
            return evaluated;
        }
        Map<Slot, ColumnRange> slotSideRanges = evaluated.childrenResult.get(0).columnRanges;
        return intersectSlotRange(
                evaluated, slotSideRanges, (Slot) lhs, ColumnRange.singleton((Literal) rhs));
    }
    if (lhs instanceof Literal && rhs instanceof Slot && isPartitionSlot((Slot) rhs)) {
        // literal = slot
        Map<Slot, ColumnRange> slotSideRanges = evaluated.childrenResult.get(1).columnRanges;
        return intersectSlotRange(
                evaluated, slotSideRanges, (Slot) rhs, ColumnRange.singleton((Literal) lhs));
    }
    return evaluated;
}
/**
 * Evaluates an {@code IN} predicate.
 *
 * When the evaluated predicate compares a partition slot with literal options only, the range of
 * the slot is intersected with the union of the option values (see {@code intersectSlotRange}).
 * Slots that are not partition slots are left untouched, like in the comparison visitors: no range
 * is tracked for them, so intersecting would dereference a missing range.
 */
@Override
public EvaluateRangeResult visitInPredicate(InPredicate inPredicate, EvaluateRangeInput context) {
    EvaluateRangeResult result = evaluateChildrenThenThis(inPredicate, context);
    if (!(result.result instanceof InPredicate)) {
        return result;
    }
    inPredicate = (InPredicate) result.result;
    if (!(inPredicate.getCompareExpr() instanceof Slot)
            || !inPredicate.getOptions().stream().allMatch(Literal.class::isInstance)) {
        return result;
    }
    Slot slot = (Slot) inPredicate.getCompareExpr();
    if (!isPartitionSlot(slot)) {
        return result;
    }
    // union of all option values: slot in (a, b) <=> slot = a or slot = b
    ColumnRange unionLiteralRange = inPredicate.getOptions()
            .stream()
            .map(Literal.class::cast)
            .map(ColumnRange::singleton)
            .reduce(ColumnRange.empty(), ColumnRange::union);
    Map<Slot, ColumnRange> slotRanges = result.childrenResult.get(0).columnRanges;
    return intersectSlotRange(result, slotRanges, slot, unionLiteralRange);
}
/**
 * Evaluates an {@code IS NULL} predicate.
 *
 * If the operand is a partition slot that is recorded in {@code partitionSlotContainsNull} as not
 * containing null, the predicate evaluates to {@code BooleanLiteral.FALSE}.
 */
@Override
public EvaluateRangeResult visitIsNull(IsNull isNull, EvaluateRangeInput context) {
    EvaluateRangeResult evaluated = evaluateChildrenThenThis(isNull, context);
    if (!(evaluated.result instanceof IsNull)) {
        return evaluated;
    }
    Expression operand = isNull.child();
    if (operand instanceof Slot && isPartitionSlot((Slot) operand)) {
        boolean mayContainNull = partitionSlotContainsNull.get((Slot) operand);
        if (!mayContainNull) {
            return new EvaluateRangeResult(
                    BooleanLiteral.FALSE, evaluated.columnRanges, evaluated.childrenResult);
        }
    }
    return evaluated;
}
/**
 * Evaluates a conjunction: the ranges of both children are intersected per slot.
 *
 * If an intersected range is empty the result is FALSE; otherwise the ranges of OTHER-type
 * slots are shrunk further against the lower and then the upper partition bound.
 */
@Override
public EvaluateRangeResult visitAnd(And and, EvaluateRangeInput context) {
    EvaluateRangeResult evaluated = evaluateChildrenThenThis(and, context);
    EvaluateRangeResult leftResult = evaluated.childrenResult.get(0);
    EvaluateRangeResult rightResult = evaluated.childrenResult.get(1);
    EvaluateRangeResult merged = returnFalseIfExistEmptyRange(
            mergeRanges(evaluated.result, leftResult, rightResult,
                    (lhsRange, rhsRange) -> lhsRange.intersect(rhsRange)));
    if (!merged.result.equals(BooleanLiteral.FALSE)) {
        // shrink range and prune the other type: if previous column is literal and equals to the bound
        merged = determinateRangeOfOtherType(merged, lowers, true);
        merged = determinateRangeOfOtherType(merged, uppers, false);
    }
    return merged;
}
/**
 * Evaluates a disjunction: the ranges of both children are unioned per slot, and the result
 * is FALSE if a unioned range is still empty.
 */
@Override
public EvaluateRangeResult visitOr(Or or, EvaluateRangeInput context) {
    EvaluateRangeResult evaluated = evaluateChildrenThenThis(or, context);
    EvaluateRangeResult leftResult = evaluated.childrenResult.get(0);
    EvaluateRangeResult rightResult = evaluated.childrenResult.get(1);
    EvaluateRangeResult merged = mergeRanges(evaluated.result, leftResult, rightResult,
            (lhsRange, rhsRange) -> lhsRange.union(rhsRange));
    return returnFalseIfExistEmptyRange(merged);
}
/**
 * Evaluates a negation: every slot range of the child is replaced by the result of
 * {@code ColumnRange.complete()}, and the result is FALSE if one of the new ranges is empty.
 */
@Override
public EvaluateRangeResult visitNot(Not not, EvaluateRangeInput context) {
    EvaluateRangeResult evaluated = evaluateChildrenThenThis(not, context);
    ImmutableMap.Builder<Slot, ColumnRange> negatedRanges = ImmutableMap.builder();
    for (Map.Entry<Slot, ColumnRange> childRange
            : evaluated.childrenResult.get(0).columnRanges.entrySet()) {
        negatedRanges.put(childRange.getKey(), childRange.getValue().complete());
    }
    EvaluateRangeResult negated = new EvaluateRangeResult(
            evaluated.result, negatedRanges.build(), evaluated.childrenResult);
    return returnFalseIfExistEmptyRange(negated);
}
/**
 * Visits all children of {@code expr}, rebuilds {@code expr} when a child changed, and then
 * folds the expression with {@code FoldConstantRuleOnFE}.
 *
 * The returned result carries the default column ranges of the context and the child results.
 */
private EvaluateRangeResult evaluateChildrenThenThis(Expression expr, EvaluateRangeInput context) {
    List<Expression> rewrittenChildren = new ArrayList<>();
    List<EvaluateRangeResult> childResults = new ArrayList<>();
    boolean anyChildChanged = false;
    for (Expression child : expr.children()) {
        EvaluateRangeResult childResult = child.accept(this, context);
        childResults.add(childResult);
        rewrittenChildren.add(childResult.result);
        // identity comparison: only rebuild the parent when a child was really replaced
        anyChildChanged |= childResult.result != child;
    }
    Expression withEvaluatedChildren = anyChildChanged ? expr.withChildren(rewrittenChildren) : expr;
    // evaluate this
    Expression folded = withEvaluatedChildren.accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext);
    return new EvaluateRangeResult(folded, context.defaultColumnRanges, childResults);
}
/**
 * Replaces a non-literal boolean result by {@code BooleanLiteral.FALSE} when one of its own
 * column ranges is empty; any other result is returned unchanged.
 */
private EvaluateRangeResult returnFalseIfExistEmptyRange(EvaluateRangeResult result) {
    Expression evaluated = result.result;
    if (!(evaluated.getDataType() instanceof BooleanType) || evaluated instanceof Literal) {
        return result;
    }
    for (ColumnRange range : result.columnRanges.values()) {
        if (range.isEmptyRange()) {
            return new EvaluateRangeResult(BooleanLiteral.FALSE, result.columnRanges, result.childrenResult);
        }
    }
    return result;
}
/**
 * Intersects the range of {@code slot} in {@code columnRanges} with {@code otherRange}.
 *
 * @return a result holding the updated ranges; its expression is {@code BooleanLiteral.FALSE} when
 *         the intersection is empty, otherwise the expression of {@code originResult}
 */
private EvaluateRangeResult intersectSlotRange(EvaluateRangeResult originResult,
        Map<Slot, ColumnRange> columnRanges, Slot slot, ColumnRange otherRange) {
    ColumnRange narrowed = columnRanges.get(slot).intersect(otherRange);
    Map<Slot, ColumnRange> updatedRanges = replaceSlotRange(columnRanges, slot, narrowed);
    Expression expression = narrowed.isEmptyRange() ? BooleanLiteral.FALSE : originResult.result;
    return new EvaluateRangeResult(expression, updatedRanges, originResult.childrenResult);
}
/**
 * Shrinks the range of the first OTHER-type slot whose preceding columns are all pinned to the
 * given partition bound.
 *
 * The columns are scanned in partition order. CONST columns are skipped; the RANGE column and every
 * OTHER column before the qualified one must be a singleton equal to its bound value, otherwise
 * nothing can be inferred. Only the first OTHER column that is not pinned to the bound gets
 * constrained: columns after it are unconstrained, so the scan stops there.
 *
 * @param context the result whose column ranges are inspected
 * @param partitionBound the lower or upper bound literals of the partition, one per partition column
 * @param isLowerBound true if {@code partitionBound} is the lower bound, false for the upper bound
 * @return {@code context} when nothing can be inferred, otherwise a result with the shrunk range;
 *         its expression is {@code BooleanLiteral.FALSE} when the shrunk range is empty
 * @throws AnalysisException if a partition slot type is unknown
 */
private EvaluateRangeResult determinateRangeOfOtherType(
        EvaluateRangeResult context, List<Literal> partitionBound, boolean isLowerBound) {
    if (context.result instanceof Literal) {
        return context;
    }
    Slot qualifiedSlot = null;
    ColumnRange qualifiedRange = null;
    // stop as soon as a slot is qualified: the `break` below only leaves the switch, and a later
    // OTHER column must not replace the qualified one
    for (int i = 0; i < partitionSlotTypes.size() && qualifiedSlot == null; i++) {
        PartitionSlotType partitionSlotType = partitionSlotTypes.get(i);
        Slot slot = partitionSlots.get(i);
        switch (partitionSlotType) {
            case CONST:
                continue;
            case RANGE: {
                ColumnRange columnRange = context.columnRanges.get(slot);
                if (!columnRange.isSingleton()
                        || !columnRange.getLowerBound().getValue().equals(partitionBound.get(i))) {
                    // the range column is not pinned to the bound, later columns are unconstrained
                    return context;
                }
                continue;
            }
            case OTHER: {
                ColumnRange columnRange = context.columnRanges.get(slot);
                if (columnRange.isSingleton()
                        && columnRange.getLowerBound().getValue().equals(partitionBound.get(i))) {
                    // still pinned to the bound, look at the next column
                    continue;
                }
                qualifiedSlot = slot;
                if (isLowerBound) {
                    qualifiedRange = ColumnRange.atLeast(partitionBound.get(i));
                } else {
                    // the upper bound is exclusive only on the last partition column
                    qualifiedRange = i + 1 == partitionSlots.size()
                            ? ColumnRange.lessThen(partitionBound.get(i))
                            : ColumnRange.atMost(partitionBound.get(i));
                }
                break;
            }
            default:
                throw new AnalysisException("Unknown partition slot type: " + partitionSlotType);
        }
    }
    if (qualifiedSlot == null) {
        return context;
    }
    ColumnRange origin = context.columnRanges.get(qualifiedSlot);
    ColumnRange newRange = origin.intersect(qualifiedRange);
    Map<Slot, ColumnRange> newRanges = replaceSlotRange(context.columnRanges, qualifiedSlot, newRange);
    Expression expression = newRange.isEmptyRange() ? BooleanLiteral.FALSE : context.result;
    return new EvaluateRangeResult(expression, newRanges, context.childrenResult);
}
/**
 * Returns an immutable copy of {@code originRange} in which {@code slot} maps to {@code range}.
 * The input map is not modified.
 */
private Map<Slot, ColumnRange> replaceSlotRange(Map<Slot, ColumnRange> originRange, Slot slot, ColumnRange range) {
    // copy into an insertion-ordered map so the existing entries keep their order
    LinkedHashMap<Slot, ColumnRange> updated = new LinkedHashMap<>(originRange);
    updated.put(slot, range);
    return ImmutableMap.copyOf(updated);
}
/**
 * Merges the column ranges of two child results slot by slot.
 *
 * @param originResult the expression of the merged result
 * @param left the first child result
 * @param right the second child result
 * @param mergeFunction combines the left and the right range of one slot
 * @return a result over the slots of both children, with {@code left} and {@code right} as children
 */
private EvaluateRangeResult mergeRanges(
        Expression originResult, EvaluateRangeResult left, EvaluateRangeResult right,
        BiFunction<ColumnRange, ColumnRange, ColumnRange> mergeFunction) {
    Map<Slot, ColumnRange> leftRanges = left.columnRanges;
    Map<Slot, ColumnRange> rightRanges = right.columnRanges;
    // slots of the left side first, then the slots only known to the right side
    Set<Slot> allSlots = ImmutableSet.<Slot>builder()
            .addAll(leftRanges.keySet())
            .addAll(rightRanges.keySet())
            .build();
    ImmutableMap.Builder<Slot, ColumnRange> merged = ImmutableMap.builder();
    for (Slot slot : allSlots) {
        merged.put(slot, mergeFunction.apply(leftRanges.get(slot), rightRanges.get(slot)));
    }
    return new EvaluateRangeResult(originResult, merged.build(), ImmutableList.of(left, right));
}
/**
 * Converts every key of {@code partitionKey} to a Nereids literal, using the type stored
 * at the same index of the partition key.
 */
private List<Literal> toNereidsLiterals(PartitionKey partitionKey) {
    ImmutableList.Builder<Literal> literals = ImmutableList.builder();
    int keyCount = partitionKey.getKeys().size();
    for (int index = 0; index < keyCount; index++) {
        LiteralExpr legacyLiteral = partitionKey.getKeys().get(index);
        PrimitiveType primitiveType = partitionKey.getTypes().get(index);
        literals.add(Literal.fromLegacyLiteral(legacyLiteral, Type.fromPrimitiveType(primitiveType)));
    }
    return literals.build();
}
/**
 * Evaluates {@code date(slot)} on a datetime RANGE partition slot.
 *
 * If the slot cannot be null and the lowest and the highest endpoint of its range fold to the same
 * date literal, the function call is replaced by that literal.
 */
@Override
public EvaluateRangeResult visitDate(Date date, EvaluateRangeInput context) {
    EvaluateRangeResult evaluated = super.visitDate(date, context);
    if (!(evaluated.result instanceof Date)) {
        return evaluated;
    }
    Date dateFunction = (Date) evaluated.result;
    Expression argument = dateFunction.child();
    if (!(argument instanceof Slot && isPartitionSlot((Slot) argument))) {
        return evaluated;
    }
    Slot partitionSlot = (Slot) argument;
    PartitionSlotType slotType = getPartitionSlotType(partitionSlot).get();
    if (slotType != PartitionSlotType.RANGE || partitionSlotContainsNull.get(partitionSlot)) {
        return evaluated;
    }
    DataType argumentType = dateFunction.child().getDataType();
    if (!(argumentType.isDateTimeType() || argumentType.isDateTimeV2Type())) {
        return evaluated;
    }
    ColumnRange slotRange = evaluated.childrenResult.get(0).columnRanges.get(partitionSlot);
    if (slotRange.isEmptyRange()) {
        return evaluated;
    }
    // compare the dates of both ends of the range
    Range<ColumnBound> span = slotRange.span();
    Literal lowest = span.lowerEndpoint().getValue();
    Literal highest = span.upperEndpoint().getValue();
    Expression lowestDate = new Date(lowest).accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext);
    Expression highestDate = new Date(highest).accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext);
    boolean sameDate = lowestDate instanceof Literal && highestDate instanceof Literal
            && lowestDate.equals(highestDate);
    if (!sameDate) {
        return evaluated;
    }
    return new EvaluateRangeResult(lowestDate, evaluated.columnRanges, evaluated.childrenResult);
}
/** Returns true if a partition slot type is registered for {@code slot}. */
private boolean isPartitionSlot(Slot slot) {
    boolean registered = slotToType.containsKey(slot);
    return registered;
}
/** Returns the partition slot type registered for {@code slot}, or empty if there is none. */
private Optional<PartitionSlotType> getPartitionSlotType(Slot slot) {
    PartitionSlotType registeredType = slotToType.get(slot);
    return Optional.ofNullable(registeredType);
}
/**
 * Gives every input the column ranges of all slots.
 *
 * First the own range of each slot is collected into one map, then every input is rebuilt with
 * its original expression and that complete map.
 */
private Map<Slot, PartitionSlotInput> fillSlotRangesToInputs(
        Map<Slot, PartitionSlotInput> inputs) {
    ImmutableMap.Builder<Slot, ColumnRange> rangesBuilder = ImmutableMap.builder();
    for (Map.Entry<Slot, PartitionSlotInput> entry : inputs.entrySet()) {
        Slot slot = entry.getKey();
        rangesBuilder.put(slot, entry.getValue().columnRanges.get(slot));
    }
    Map<Slot, ColumnRange> allColumnRanges = rangesBuilder.build();

    ImmutableMap.Builder<Slot, PartitionSlotInput> filledInputs = ImmutableMap.builder();
    for (Slot slot : inputs.keySet()) {
        filledInputs.put(slot, new PartitionSlotInput(inputs.get(slot).result, allColumnRanges));
    }
    return filledInputs.build();
}
/**
 * EvaluateRangeInput.
 *
 * the context passed to every visit method: the default column ranges and the input of each slot.
 */
public static class EvaluateRangeInput {
// the ranges given to results that do not compute ranges of their own
private Map<Slot, ColumnRange> defaultColumnRanges;
// the replacement expression and column ranges registered for each slot
private Map<Slot, PartitionSlotInput> slotToInput;
public EvaluateRangeInput(Map<Slot, ColumnRange> defaultColumnRanges,
Map<Slot, PartitionSlotInput> slotToInput) {
this.defaultColumnRanges = defaultColumnRanges;
this.slotToInput = slotToInput;
}
}
/**
 * EvaluateRangeResult.
 *
 * binds an expression to its ColumnRanges, so we can not only compute the expression tree, but also
 * compute the ranges.
 * if a column range is an empty range, the predicate should return BooleanLiteral.FALSE, which means
 * this partition can be pruned.
 */
public static class EvaluateRangeResult {
// the evaluated (possibly rewritten) expression
private final Expression result;
// the column range of each slot after evaluating the expression
private final Map<Slot, ColumnRange> columnRanges;
// the results of the children of the expression, in child order
private final List<EvaluateRangeResult> childrenResult;
public EvaluateRangeResult(Expression result, Map<Slot, ColumnRange> columnRanges,
List<EvaluateRangeResult> childrenResult) {
this.result = result;
this.columnRanges = columnRanges;
this.childrenResult = childrenResult;
}
}
}

View File

@ -0,0 +1,94 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rewrite.rules;
import org.apache.doris.catalog.ListPartitionItem;
import org.apache.doris.catalog.PartitionInfo;
import org.apache.doris.catalog.PartitionItem;
import org.apache.doris.catalog.RangePartitionItem;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
/** PartitionPruner */
public class PartitionPruner {
private List<OnePartitionEvaluator> partitions;
private Expression partitionPredicate;
private PartitionPruner(List<OnePartitionEvaluator> partitions, Expression partitionPredicate) {
this.partitions = Objects.requireNonNull(partitions, "partitions cannot be null");
this.partitionPredicate = Objects.requireNonNull(partitionPredicate, "partitionPredicate cannot be null");
}
public List<Long> prune() {
return partitions.stream()
.filter(partitionEvaluator -> !canPrune(partitionEvaluator))
.map(OnePartitionEvaluator::getPartitionId)
.collect(ImmutableList.toImmutableList());
}
/** prune partition */
public static List<Long> prune(List<Slot> partitionSlots, Expression partitionPredicate,
PartitionInfo partitionInfo, CascadesContext cascadesContext) {
partitionPredicate = TryEliminateUninterestedPredicates.rewrite(
partitionPredicate, ImmutableSet.copyOf(partitionSlots), cascadesContext);
Map<Long, PartitionItem> idToPartitions = partitionInfo.getIdToItem(false);
List<OnePartitionEvaluator> evaluators = idToPartitions.entrySet()
.stream()
.map(kv -> toPartitionEvaluator(kv.getKey(), kv.getValue(), partitionSlots, cascadesContext))
.collect(ImmutableList.toImmutableList());
PartitionPruner partitionPruner = new PartitionPruner(evaluators, partitionPredicate);
return partitionPruner.prune();
}
/** convert partition item to partition evaluator */
public static final OnePartitionEvaluator toPartitionEvaluator(long id, PartitionItem partitionItem,
List<Slot> partitionSlots, CascadesContext cascadesContext) {
if (partitionItem instanceof ListPartitionItem) {
return new OneListPartitionEvaluator(
id, partitionSlots, (ListPartitionItem) partitionItem, cascadesContext);
} else if (partitionItem instanceof RangePartitionItem) {
return new OneRangePartitionEvaluator(
id, partitionSlots, (RangePartitionItem) partitionItem, cascadesContext);
} else {
return new UnknownPartitionEvaluator(id, partitionItem);
}
}
private boolean canPrune(OnePartitionEvaluator evaluator) {
List<Map<Slot, PartitionSlotInput>> onePartitionInputs = evaluator.getOnePartitionInputs();
for (Map<Slot, PartitionSlotInput> currentInputs : onePartitionInputs) {
Expression result = evaluator.evaluate(partitionPredicate, currentInputs);
if (!result.equals(BooleanLiteral.FALSE)) {
return false;
}
}
return true;
}
}

View File

@ -0,0 +1,285 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rewrite.rules;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.LargeIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.types.DataType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import org.apache.commons.lang.time.DateFormatUtils;
import org.apache.commons.lang3.time.DateUtils;
import java.math.BigInteger;
import java.text.ParseException;
import java.time.temporal.ChronoUnit;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.function.Function;
/**
* PartitionRangeExpander
*
* try to expand range partition if the partition is enumerable
* example:
* partition column type: int, range [1, 10), return [1, 2, 3, 4, 5, 6, 7, 8, 9).
*
* after expand range, we can replace partition slot to the literal in expression tree and evaluate it.
*/
public class PartitionRangeExpander {
/** PartitionSlotType */
public enum PartitionSlotType {
// e.g. the first partition column is const '1' in partition [('1', '2', '5'), ('1', '3', '5')),
// we can substitute the slot in the expression tree and evaluate.
CONST,
// e.g. the second partition column is range ['2', '3'] in partition [('1', '2', '5'), ('1', '3', '5'))
// if the partition column is discrete type(int, date), we expand and iterate it and substitute the slot
// in the expression tree and evaluate, else use range set to check whether the partition is valid range
RANGE,
// e.g. the third partition column other type in partition [('1', '2', '5'), ('1', '3', '5')),
// every partition column after the first range type column is other type.
// we can check the range if the previous partition column equals to the bound, e.g.
// if first_part_column = '1' and second_part_column = '2', then we can infer third_part_column >= '5'.
// if first_part_column = '1' and second_part_column = '3', then we can infer third_part_column < '5'.
OTHER;
}
/** expandRangeLiterals */
public final List<List<Expression>> tryExpandRange(
List<Slot> partitionSlots, List<Literal> lowers, List<Literal> uppers,
List<PartitionSlotType> partitionSlotTypes, int expandThreshold) {
long expandedCount = 1;
List<List<Expression>> expandedLists = Lists.newArrayListWithCapacity(lowers.size());
for (int i = 0; i < partitionSlotTypes.size(); i++) {
Slot slot = partitionSlots.get(i);
PartitionSlotType partitionSlotType = partitionSlotTypes.get(i);
List<Expression> expandedList = Lists.newArrayList();
Literal lower = lowers.get(i);
switch (partitionSlotType) {
case CONST:
// don't need expanded, just replace to literal as input
expandedList.add(lower);
break;
case RANGE:
// try to expand range to literals as input
// e.g. [1, 5) will be expand to [1, 2, 3, 4] if the data type is integer like type.
// some types can not expend, like varchar type
Literal upper = uppers.get(i);
try {
boolean isLastColumn = i + 1 == partitionSlots.size();
if (canExpandRange(slot, lower, upper, expandedCount, expandThreshold)) {
expandedList.addAll(ImmutableList.copyOf(
enumerableIterator(slot, lower, upper, isLastColumn))
);
} else {
expandedList.add(slot);
}
} catch (Throwable t) {
// catch for safety, should not invoke here
expandedList.add(slot);
}
break;
case OTHER:
// can't expend other slots, keep slot as input
expandedList.add(slot);
break;
default:
throw new AnalysisException("Unknown partition slot type: " + partitionSlotType);
}
expandedCount *= expandedList.size();
expandedLists.add(expandedList);
}
return expandedLists;
}
private final boolean canExpandRange(Slot slot, Literal lower, Literal upper,
long expandedCount, int expandThreshold) {
DataType type = slot.getDataType();
if (!type.isIntegerLikeType() && !type.isDateType() && !type.isDateV2Type()) {
return false;
}
try {
long count = enumerableCount(slot.getDataType(), lower, upper);
if (count <= 0) {
return false;
}
// too much expanded will consuming resources of frontend,
// e.g. [1, 100000000), we should skip expand it
return (expandedCount * count) <= expandThreshold;
} catch (Throwable t) {
// e.g. max_value can not expand
return false;
}
}
/** the types will like this: [CONST, CONST, ..., RANGE, OTHER, OTHER, ...] */
public List<PartitionSlotType> computePartitionSlotTypes(List<Literal> lowers, List<Literal> uppers) {
PartitionSlotType previousType = PartitionSlotType.CONST;
List<PartitionSlotType> types = Lists.newArrayListWithCapacity(lowers.size());
for (int i = 0; i < lowers.size(); ++i) {
if (previousType == PartitionSlotType.RANGE || previousType == PartitionSlotType.OTHER) {
types.add(PartitionSlotType.OTHER);
continue;
}
Literal lower = lowers.get(i);
Literal upper = uppers.get(i);
PartitionSlotType type = lower.toLegacyLiteral().equals(upper.toLegacyLiteral())
? PartitionSlotType.CONST
: PartitionSlotType.RANGE;
types.add(type);
previousType = type;
}
return types;
}
private final long enumerableCount(DataType dataType, Literal startInclusive, Literal endExclusive) throws
ParseException {
if (dataType.isIntegerLikeType()) {
BigInteger start = new BigInteger(startInclusive.getStringValue());
BigInteger end = new BigInteger(endExclusive.getStringValue());
return end.subtract(start).longValue();
} else if (dataType.isDateType() || dataType.isDateV2Type()) {
Date start = DateUtils.parseDate(startInclusive.toString(), DateLiteral.JAVA_DATE_FORMAT);
Date end = DateUtils.parseDate(endExclusive.toString(), DateLiteral.JAVA_DATE_FORMAT);
return ChronoUnit.DAYS.between(start.toInstant(), end.toInstant());
}
// not enumerable
return -1;
}
private final Iterator<? extends Expression> enumerableIterator(
Slot slot, Literal startInclusive, Literal endLiteral, boolean endExclusive) throws ParseException {
DataType dataType = slot.getDataType();
if (dataType.isIntegerLikeType()) {
BigInteger start = new BigInteger(startInclusive.getStringValue());
BigInteger end = new BigInteger(endLiteral.getStringValue());
if (dataType.isTinyIntType()) {
return new IntegerLikeRangePartitionValueIterator<>(
start, end, endExclusive, value -> new TinyIntLiteral(value.byteValue()));
} else if (dataType.isSmallIntType()) {
return new IntegerLikeRangePartitionValueIterator<>(
start, end, endExclusive, value -> new SmallIntLiteral(value.shortValue()));
} else if (dataType.isIntegerType()) {
return new IntegerLikeRangePartitionValueIterator<>(
start, end, endExclusive, value -> new IntegerLiteral(value.intValue()));
} else if (dataType.isBigIntType()) {
return new IntegerLikeRangePartitionValueIterator<>(
start, end, endExclusive, value -> new BigIntLiteral(value.longValue()));
} else if (dataType.isLargeIntType()) {
return new IntegerLikeRangePartitionValueIterator<>(
start, end, endExclusive, LargeIntLiteral::new);
}
} else if (dataType.isDateType()) {
Date startDate = DateUtils.parseDate(startInclusive.toString(), DateLiteral.JAVA_DATE_FORMAT);
Date endDate = DateUtils.parseDate(endLiteral.toString(), DateLiteral.JAVA_DATE_FORMAT);
return new DateLikeRangePartitionValueIterator<>(startDate, endDate, endExclusive,
date -> new DateLiteral(DateFormatUtils.format(date, DateLiteral.JAVA_DATE_FORMAT)));
} else if (dataType.isDateV2Type()) {
Date startDate = DateUtils.parseDate(startInclusive.toString(), DateLiteral.JAVA_DATE_FORMAT);
Date endDate = DateUtils.parseDate(endLiteral.toString(), DateLiteral.JAVA_DATE_FORMAT);
return new DateLikeRangePartitionValueIterator<>(startDate, endDate, endExclusive,
date -> new DateV2Literal(DateFormatUtils.format(date, DateLiteral.JAVA_DATE_FORMAT)));
}
// unsupported type
return Iterators.singletonIterator(slot);
}
private class IntegerLikeRangePartitionValueIterator<L extends IntegerLikeLiteral>
extends RangePartitionValueIterator<BigInteger, L> {
public IntegerLikeRangePartitionValueIterator(BigInteger startInclusive, BigInteger end,
boolean endExclusive, Function<BigInteger, L> toLiteral) {
super(startInclusive, end, endExclusive, toLiteral);
}
@Override
protected BigInteger doGetNext(BigInteger current) {
return current.add(BigInteger.ONE);
}
}
/**
 * Iterates the dates of a range partition, day by day.
 *
 * <p>The supplied function turns each {@link Date} into the literal type handed out by the iterator.
 *
 * @param <L> the literal type produced for every date
 */
private class DateLikeRangePartitionValueIterator<L extends Literal>
        extends RangePartitionValueIterator<Date, L> {

    /**
     * @param firstDate the first date to produce (inclusive)
     * @param endBound the upper bound of the iteration
     * @param endExclusive whether {@code endBound} itself is excluded from the iteration
     * @param literalFactory converts one date into its literal
     */
    public DateLikeRangePartitionValueIterator(
            Date firstDate, Date endBound, boolean endExclusive, Function<Date, L> literalFactory) {
        super(firstDate, endBound, endExclusive, literalFactory);
    }

    /** The successor of a date is the same date moved forward by one day. */
    @Override
    protected Date doGetNext(Date previous) {
        return DateUtils.addDays(previous, 1);
    }
}
/**
 * Base iterator that enumerates the discrete values between a start and an end bound and
 * converts every value into a literal.
 *
 * <p>Sub classes only define how to step from one value to the next, see {@link #doGetNext}.
 *
 * @param <C> the comparable value type used to walk the range; the bound is
 *            {@code Comparable<? super C>} instead of the raw {@code Comparable} so that
 *            the {@code compareTo} calls are type checked
 * @param <L> the literal type produced for every value
 */
private abstract class RangePartitionValueIterator<C extends Comparable<? super C>, L extends Literal>
        implements Iterator<L> {
    private final C startInclusive;
    private final C end;
    private final boolean endExclusive;
    // the value that the following call of next() will return
    private C current;
    private final Function<C, L> toLiteral;

    /**
     * @param startInclusive the first value to produce
     * @param end the upper bound of the iteration
     * @param endExclusive if true the iteration stops before {@code end}, otherwise {@code end} is produced too
     * @param toLiteral converts one value into its literal
     */
    public RangePartitionValueIterator(C startInclusive, C end, boolean endExclusive, Function<C, L> toLiteral) {
        this.startInclusive = startInclusive;
        this.end = end;
        this.endExclusive = endExclusive;
        this.current = this.startInclusive;
        this.toLiteral = toLiteral;
    }

    @Override
    public boolean hasNext() {
        int comparedToEnd = current.compareTo(end);
        return endExclusive ? comparedToEnd < 0 : comparedToEnd <= 0;
    }

    /**
     * Returns the literal of the current value and advances to the following value.
     *
     * @throws NoSuchElementException if the range is exhausted
     */
    @Override
    public L next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        C value = current;
        // advance before converting, so the conversion result is the last thing computed
        current = doGetNext(value);
        return toLiteral.apply(value);
    }

    /** Computes the value that directly follows {@code current}. */
    protected abstract C doGetNext(C current);
}
}

View File

@ -0,0 +1,125 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rewrite.rules;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import com.google.common.collect.ImmutableMap;
import java.util.Map;
/**
 * PartitionSlotInput, the input of one partition slot.
 * The partition slot inside the partition predicate is replaced by PartitionSlotInput#result,
 * so that the expression tree can be evaluated.
 *
 * For example, take the partition predicate `part_column1 > 1`, where part_column1 is of int type,
 * and a partition whose range is [('1'), ('4')):
 *
 *            GreaterThan                                     GreaterThan
 *           /           \                    ->             /           \
 *  Slot(part_column1)  IntegerLiteral(1)          IntegerLiteral(n)  IntegerLiteral(1)
 *          |                                             ^
 *          |                                             |
 *          +---------------------------------------------+
 *                                 |
 *                             replace by
 *                PartitionSlotInput(result = IntegerLiteral(1))
 *                PartitionSlotInput(result = IntegerLiteral(2))
 *                PartitionSlotInput(result = IntegerLiteral(3))
 *
 * If the partition slot is not enumerable (some RANGE / all OTHER partition slot types), the slot is kept:
 *                PartitionSlotInput(result = Slot(part_column1))
 */
public class PartitionSlotInput {
    // the expression which takes the place of the partition slot
    public final Expression result;

    // The range map of all partition slots; for the example in the class comment it is
    // `{Slot(part_column1): [1, 4)}`.
    // These ranges are used as the initial partition slot ranges. Every expression has a related
    // columnRanges map, and while the expression is evaluated the columnRanges map of each upper
    // expression is computed.
    //
    // For example, take the predicate `part_column1 > 100 or part_column1 < 0` and the partition
    // range [1, 10000). The range holds too many values, so by default it is not expanded to IntLiterals.
    // These are the processing steps:
    //
    // step 1: every leaf starts with the partition range
    //
    //                                       Or
    //                      /                                  \
    //               GreaterThan                             LessThan
    //             /             \                         /          \
    //    part_column1    IntegerLiteral(100)     part_column1    IntegerLiteral(0)
    //   (part_column1:     (part_column1:       (part_column1:     (part_column1:
    //     [1,10000))         [1,10000))           [1,10000))         [1,10000))
    //
    //                                       |
    //                                       v
    //
    // step 2: each comparison combines the partition range with its own range
    //
    //                                       Or
    //                      /                                  \
    //               GreaterThan                             LessThan
    //   (part_column1: [1,10000) and (100, +∞))   (part_column1: [1,10000) and (-∞, 0))
    //             /             \                         /          \
    //    part_column1    IntegerLiteral(100)     part_column1    IntegerLiteral(0)
    //
    //                                       |
    //                                       v
    //
    // step 3: the combined ranges are simplified
    //
    //                                       Or
    //                      /                                  \
    //               GreaterThan                             LessThan
    //     (part_column1: (100,10000))             (part_column1: empty range)
    //             /             \                         /          \
    //    part_column1    IntegerLiteral(100)     part_column1    IntegerLiteral(0)
    //
    //                                       |
    //                                       v
    //
    // step 4: an empty range means the predicate is false
    //
    //                                       Or
    //                      /                                  \
    //               GreaterThan                      BooleanLiteral.FALSE    <- empty set to false
    //     (part_column1: (100,10000))
    //             /             \
    //    part_column1    IntegerLiteral(100)
    //
    //                                       |
    //                                       v
    //
    // step 5: the Or merges the ranges of its children
    //
    //                                       Or
    //                   (part_column1: (100,10000) or empty range)
    //                      /                                  \
    //               GreaterThan                      BooleanLiteral.FALSE
    //             /             \
    //    part_column1    IntegerLiteral(100)
    //
    //                                       |
    //                                       v
    //
    // step 6: the Or is folded away
    //
    //               GreaterThan                   <- fold `expr or false` to expr
    //     (part_column1: (100,10000))             <- merge columnRanges
    //             /             \
    //    part_column1    IntegerLiteral(100)
    //
    // Because this predicate can not be folded to BooleanLiteral.FALSE, the partition has to be scanned.
    public final Map<Slot, ColumnRange> columnRanges;

    /**
     * @param result the expression which replaces the partition slot
     * @param columnRanges the ranges of the partition slots, an immutable copy of this map is kept
     */
    public PartitionSlotInput(Expression result, Map<Slot, ColumnRange> columnRanges) {
        this.columnRanges = ImmutableMap.copyOf(columnRanges);
        this.result = result;
    }
}

View File

@ -51,33 +51,27 @@ public class SimplifyNotExprRule extends AbstractExpressionRewriteRule {
Expression child = not.child();
if (child instanceof ComparisonPredicate) {
ComparisonPredicate cp = (ComparisonPredicate) not.child();
Expression left = rewrite(cp.left(), context);
Expression right = rewrite(cp.right(), context);
Expression left = cp.left();
Expression right = cp.right();
// TODO: visit concrete class instead of `instanceof`.
if (child instanceof GreaterThan) {
return new LessThanEqual(left, right);
return new LessThanEqual(left, right).accept(this, context);
} else if (child instanceof GreaterThanEqual) {
return new LessThan(left, right);
return new LessThan(left, right).accept(this, context);
} else if (child instanceof LessThan) {
return new GreaterThanEqual(left, right);
return new GreaterThanEqual(left, right).accept(this, context);
} else if (child instanceof LessThanEqual) {
return new GreaterThan(left, right);
} else {
not.withChildren(child.withChildren(left, right));
return new GreaterThan(left, right).accept(this, context);
}
} else if (child instanceof CompoundPredicate) {
CompoundPredicate cp = (CompoundPredicate) not.child();
Expression left = rewrite(new Not(cp.left()), context);
Expression right = rewrite(new Not(cp.right()), context);
return cp.flip(left, right);
CompoundPredicate cp = (CompoundPredicate) child;
Not left = new Not(cp.left());
Not right = new Not(cp.right());
return cp.flip(left, right).accept(this, context);
} else if (child instanceof Not) {
return child.child(0).accept(this, context);
}
if (child instanceof Not) {
Not son = (Not) child;
return rewrite(son.child(), context);
}
return not;
return super.visitNot(not, context);
}
}

View File

@ -0,0 +1,120 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rewrite.rules;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rewrite.rules.TryEliminateUninterestedPredicates.Context;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import java.util.Set;
/**
 * TryEliminateUninterestedPredicates
 *
 * This rewriter is usually used to extract the predicates which are related to the partition columns:
 * it tries to eliminate the predicates which only refer to uninterested slots by replacing them with true.
 *
 * e.g.
 *    (part = 1 and non_part = 'a') or (part = 2)
 * -> (part = 1 and true) or (part = 2)
 * -> (part = 1) or (part = 2)
 *
 * The elimination may fail in some special cases, e.g. (non_part + part) = 2.
 * The key point is: if a predicate (an expression of boolean type) only contains uninterested slots,
 * it can be eliminated.
 */
public class TryEliminateUninterestedPredicates extends DefaultExpressionRewriter<Context> {
    private final Set<Slot> interestedSlots;
    private final ExpressionRewriteContext expressionRewriteContext;

    private TryEliminateUninterestedPredicates(Set<Slot> interestedSlots, CascadesContext cascadesContext) {
        this.interestedSlots = interestedSlots;
        this.expressionRewriteContext = new ExpressionRewriteContext(cascadesContext);
    }

    /**
     * Rewrites the expression so that, where possible, only predicates on the interested slots remain.
     *
     * @param expression the predicate to rewrite
     * @param interestedSlots the slots whose predicates should be kept
     * @param cascadesContext used to build the context of the constant folding
     * @return the rewritten expression
     */
    public static Expression rewrite(Expression expression, Set<Slot> interestedSlots,
            CascadesContext cascadesContext) {
        // `Not` must be pushed down under the CompoundPredicate before the uninterested predicates
        // are eliminated; the simplify rule is invoked with a null context here
        Expression notPushedDown = expression.accept(new SimplifyNotExprRule(), null);
        return notPushedDown.accept(
                new TryEliminateUninterestedPredicates(interestedSlots, cascadesContext), new Context());
    }

    @Override
    public Expression visit(Expression originExpr, Context parentContext) {
        // the children report the kinds of slots they contain into a fresh context (postorder traversal)
        Context ownContext = new Context();
        Expression rewritten = super.visit(originExpr, ownContext);

        Expression result;
        if (!rewritten.getDataType().isBooleanType()) {
            // not a predicate, recover the origin expression:
            //    ((uninterested slot b > 0) + 1) > 1
            // -> (true + 1) > 1
            // -> ((uninterested slot b > 0) + 1) > 1 (recovered because `true + 1` is not a predicate)
            // -> true (contains no interested slot but an uninterested slot)
            result = originExpr;
        } else if (ownContext.childrenContainsNonInterestedSlots && !ownContext.childrenContainsInterestedSlots) {
            // the predicate only contains uninterested slots; the runtime value of these slots is unknown,
            // so a true value is propagated upwards to eliminate them:
            //    not(uninterested slot b > 1)
            // -> not(true)
            // -> true
            result = BooleanLiteral.TRUE;
        } else {
            // a predicate which mixes interested and uninterested slots can not be eliminated, e.g.
            //    not(uninterested slot b + interested slot a > 1)
            // stays as it is. The predicate is only simplified, which may eliminate interested slots too:
            //    ((interested slot a) and not(uninterested slot b > 1)) or true
            // -> ((interested slot a) and not(true)) or true
            // -> ((interested slot a) and true) or true
            // -> (interested slot a) or true
            // -> true
            result = rewritten.accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext);
        }

        // report the slots found below this expression to the parent
        parentContext.childrenContainsInterestedSlots |= ownContext.childrenContainsInterestedSlots;
        parentContext.childrenContainsNonInterestedSlots |= ownContext.childrenContainsNonInterestedSlots;
        return result;
    }

    @Override
    public Expression visitSlot(Slot slot, Context context) {
        // record which kind of slot was seen, the slot itself is never rewritten
        if (interestedSlots.contains(slot)) {
            context.childrenContainsInterestedSlots = true;
        } else {
            context.childrenContainsNonInterestedSlots = true;
        }
        return slot;
    }

    /** Context, records which kinds of slots an expression contains. */
    public static class Context {
        private boolean childrenContainsInterestedSlots;
        private boolean childrenContainsNonInterestedSlots;
    }
}

View File

@ -0,0 +1,55 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rewrite.rules;
import org.apache.doris.catalog.PartitionItem;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.Map;
/**
 * UnknownPartitionEvaluator, an evaluator which never prunes its partition:
 * it offers a single empty input map and hands every expression back unchanged.
 */
public class UnknownPartitionEvaluator implements OnePartitionEvaluator {
    private final long partitionId;
    private final PartitionItem partitionItem;

    public UnknownPartitionEvaluator(long partitionId, PartitionItem partitionItem) {
        this.partitionId = partitionId;
        this.partitionItem = partitionItem;
    }

    @Override
    public long getPartitionId() {
        return partitionId;
    }

    /** Returns a list which holds exactly one empty slot-to-input map. */
    @Override
    public List<Map<Slot, PartitionSlotInput>> getOnePartitionInputs() {
        Map<Slot, PartitionSlotInput> noInputs = ImmutableMap.of();
        return ImmutableList.of(noInputs);
    }

    /** Returns the given expression as it is, the inputs are ignored. */
    @Override
    public Expression evaluate(Expression expression, Map<Slot, PartitionSlotInput> currentInputs) {
        // do not prune
        return expression;
    }
}

View File

@ -17,46 +17,22 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.PartitionInfo;
import org.apache.doris.catalog.PartitionItem;
import org.apache.doris.catalog.PartitionType;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.rewrite.rules.PartitionPruner;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.planner.ColumnBound;
import org.apache.doris.planner.ColumnRange;
import org.apache.doris.planner.ListPartitionPrunerV2;
import org.apache.doris.planner.PartitionPruner;
import org.apache.doris.planner.RangePartitionPrunerV2;
import org.apache.doris.planner.ScanNode.ColumnRanges;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Range;
import com.google.common.collect.Sets;
import org.apache.commons.collections.CollectionUtils;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
* Used to prune partition of olap scan, should execute after SwapProjectAndFilter, MergeConsecutiveFilters,
@ -71,112 +47,25 @@ public class PruneOlapScanPartition extends OneRewriteRuleFactory {
LogicalOlapScan scan = filter.child();
OlapTable table = scan.getTable();
Set<String> partitionColumnNameSet = Utils.execWithReturnVal(table::getPartitionColumnNames);
PartitionInfo partitionInfo = table.getPartitionInfo();
if (partitionColumnNameSet.isEmpty()) {
return ctx.root;
}
Set<Expression> expressionList = filter.getConjuncts();
// TODO: Process all partition column for now, better to process required column only.
Map<String, ColumnRange> columnNameToRange = Maps.newHashMap();
for (String colName : partitionColumnNameSet) {
ColumnRange columnRange = createColumnRange(colName, expressionList);
columnNameToRange.put(colName, columnRange);
return filter;
}
Map<Long, PartitionItem> keyItemMap = partitionInfo.getIdToItem(false);
PartitionPruner partitionPruner = partitionInfo.getType().equals(PartitionType.RANGE)
? new RangePartitionPrunerV2(keyItemMap,
partitionInfo.getPartitionColumns(), columnNameToRange) : new ListPartitionPrunerV2(keyItemMap,
partitionInfo.getPartitionColumns(), columnNameToRange);
Collection<Long> selectedPartitionId = Utils.execWithReturnVal(partitionPruner::prune);
List<Long> manuallySpecifiedPartitions = scan.getManuallySpecifiedPartitions();
if (!CollectionUtils.isEmpty(manuallySpecifiedPartitions)) {
selectedPartitionId.retainAll(manuallySpecifiedPartitions);
}
LogicalOlapScan rewrittenScan =
scan.withSelectedPartitionIds(new ArrayList<>(selectedPartitionId));
Map<String, Slot> scanOutput = scan.getOutput()
.stream()
.collect(Collectors.toMap(slot -> slot.getName().toLowerCase(), Function.identity()));
PartitionInfo partitionInfo = table.getPartitionInfo();
List<Slot> partitionSlots = partitionInfo.getPartitionColumns()
.stream()
.map(column -> scanOutput.get(column.getName().toLowerCase()))
.collect(Collectors.toList());
List<Long> prunedPartitions = PartitionPruner.prune(
partitionSlots, filter.getPredicate(), partitionInfo, ctx.cascadesContext);
LogicalOlapScan rewrittenScan = scan.withSelectedPartitionIds(prunedPartitions);
return new LogicalFilter<>(filter.getConjuncts(), rewrittenScan);
}).toRule(RuleType.OLAP_SCAN_PARTITION_PRUNE);
}
/**
 * Builds the range of one partition column from the filter conjuncts.
 *
 * <p>Only the conjuncts which contain exactly one SlotReference, whose name equals {@code colName},
 * are taken into account. A conjunct which is an Or only contributes if every one of its disjuncts
 * can be converted; any other conjunct contributes if it can be converted by itself.
 *
 * @param colName the name of the partition column
 * @param expressionList the conjuncts of the filter
 * @return the range collected for the column
 */
private ColumnRange createColumnRange(String colName, Set<Expression> expressionList) {
    ColumnRange columnRange = ColumnRange.create();
    for (Expression conjunct : expressionList) {
        Set<SlotReference> referencedSlots = conjunct.collect(SlotReference.class::isInstance);
        boolean refersOnlyToColumn = referencedSlots.size() == 1
                && referencedSlots.iterator().next().getName().equals(colName);
        if (!refersOnlyToColumn) {
            continue;
        }

        if (!(conjunct instanceof Or)) {
            // a plain conjunct narrows the range directly
            ColumnRanges converted = exprToRanges(conjunct, colName);
            switch (converted.type) {
                case IS_NULL:
                    columnRange.setHasConjunctiveIsNull(true);
                    break;
                case CONVERT_SUCCESS:
                    columnRange.intersect(converted.ranges);
                    break;
                case CONVERT_FAILURE:
                default:
                    break;
            }
            continue;
        }

        List<Expression> disjuncts = ExpressionUtils.extractDisjunction(conjunct);
        if (disjuncts.isEmpty()) {
            continue;
        }
        List<Range<ColumnBound>> disjunctRanges = Lists.newArrayList();
        boolean sawIsNull = false;
        boolean allConverted = true;
        for (Expression disjunct : disjuncts) {
            ColumnRanges converted = exprToRanges(disjunct, colName);
            switch (converted.type) {
                case IS_NULL:
                    sawIsNull = true;
                    break;
                case CONVERT_SUCCESS:
                    disjunctRanges.addAll(converted.ranges);
                    break;
                case CONVERT_FAILURE:
                default:
                    allConverted = false;
                    break;
            }
            if (!allConverted) {
                // one failing disjunct makes the whole Or unusable, the rest is not inspected
                break;
            }
        }
        if (allConverted && (!disjunctRanges.isEmpty() || sawIsNull)) {
            columnRange.intersect(disjunctRanges);
            columnRange.setHasDisjunctiveIsNull(sawIsNull);
        }
    }
    return columnRange;
}
/**
 * Converts one comparison against a literal into a range of column bounds.
 *
 * <p>The conversion fails if the expression is not a ComparisonPredicate, if its right child is not
 * a constant Literal, or if the comparison is none of =, >=, >, <, <=.
 *
 * @param expression the expression to convert
 * @param colName the name of the partition column, not read by this method
 * @return the converted ranges, or a failure result
 */
private ColumnRanges exprToRanges(Expression expression, String colName) {
    // TODO: process in/is null expression
    if (!(expression instanceof ComparisonPredicate)) {
        return ColumnRanges.createFailure();
    }
    Expression rightChild = ((ComparisonPredicate) expression).child(1);
    boolean isConstantLiteral = rightChild != null && rightChild.isConstant() && rightChild instanceof Literal;
    if (!isConstantLiteral) {
        return ColumnRanges.createFailure();
    }

    LiteralExpr legacyValue = ((Literal) rightChild).toLegacyLiteral();
    Range<ColumnBound> range;
    if (expression instanceof EqualTo) {
        ColumnBound point = ColumnBound.of(legacyValue);
        range = Range.closed(point, point);
    } else if (expression instanceof GreaterThanEqual) {
        range = Range.atLeast(ColumnBound.of(legacyValue));
    } else if (expression instanceof GreaterThan) {
        range = Range.greaterThan(ColumnBound.of(legacyValue));
    } else if (expression instanceof LessThan) {
        range = Range.lessThan(ColumnBound.of(legacyValue));
    } else if (expression instanceof LessThanEqual) {
        range = Range.atMost(ColumnBound.of(legacyValue));
    } else {
        // any other comparison is not supported
        return ColumnRanges.createFailure();
    }

    List<Range<ColumnBound>> ranges = Lists.newArrayList();
    ranges.add(range);
    return ColumnRanges.create(ranges);
}
}

View File

@ -39,6 +39,7 @@ import java.time.temporal.TemporalAccessor;
* Date literal in Nereids.
*/
public class DateLiteral extends Literal {
public static final String JAVA_DATE_FORMAT = "yyyy-MM-dd";
protected static DateTimeFormatter DATE_FORMATTER = null;
protected static DateTimeFormatter DATE_FORMATTER_TWO_DIGIT = null;

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions.literal;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.Expression;
@ -271,6 +272,51 @@ public abstract class Literal extends Expression implements LeafExpression, Comp
return this instanceof StringLiteral || this instanceof CharLiteral || this instanceof VarcharLiteral;
}
/**
 * fromLegacyLiteral, converts a legacy literal to the nereids literal of the given catalog type.
 *
 * <p>A legacy MaxLiteral becomes a nereids MaxLiteral of the converted type; every other literal is
 * rebuilt from the string value of the legacy literal.
 *
 * @param literalExpr the legacy literal
 * @param type the catalog type which decides the kind of the nereids literal
 * @return the nereids literal
 * @throws AnalysisException if the type is not supported by this conversion
 */
public static Literal fromLegacyLiteral(LiteralExpr literalExpr, Type type) {
    DataType dataType = DataType.fromCatalogType(type);
    if (literalExpr instanceof org.apache.doris.analysis.MaxLiteral) {
        return new MaxLiteral(dataType);
    }

    String text = literalExpr.getStringValue();
    // integer types
    if (dataType.isTinyIntType()) {
        return new TinyIntLiteral(Byte.parseByte(text));
    }
    if (dataType.isSmallIntType()) {
        return new SmallIntLiteral(Short.parseShort(text));
    }
    if (dataType.isIntegerType()) {
        return new IntegerLiteral(Integer.parseInt(text));
    }
    if (dataType.isBigIntType()) {
        return new BigIntLiteral(Long.parseLong(text));
    }
    if (dataType.isLargeIntType()) {
        return new LargeIntLiteral(new BigInteger(text));
    }
    // string types
    if (dataType.isStringType()) {
        return new StringLiteral(text);
    }
    if (dataType.isCharType()) {
        return new CharLiteral(text, ((CharType) dataType).getLen());
    }
    if (dataType.isVarcharType()) {
        return new VarcharLiteral(text, ((VarcharType) dataType).getLen());
    }
    // floating point and decimal types
    if (dataType.isFloatType()) {
        return new FloatLiteral(Float.valueOf(text));
    }
    if (dataType.isDoubleType()) {
        return new DoubleLiteral(Double.valueOf(text));
    }
    if (dataType.isDecimalV2Type()) {
        return new DecimalLiteral((DecimalV2Type) dataType, new BigDecimal(text));
    }
    if (dataType.isDecimalV3Type()) {
        return new DecimalV3Literal((DecimalV3Type) dataType, new BigDecimal(text));
    }
    // date and datetime types
    if (dataType.isDateType()) {
        return new DateLiteral(text);
    }
    if (dataType.isDateV2Type()) {
        return new DateV2Literal(text);
    }
    if (dataType.isDateTimeType()) {
        return new DateTimeLiteral(text);
    }
    if (dataType.isDateTimeV2Type()) {
        return new DateTimeV2Literal(text);
    }
    throw new AnalysisException("Unsupported convert the " + literalExpr.getType()
            + " of legacy literal to nereids literal");
}
@Override
public boolean equals(Object o) {
if (this == o) {
@ -295,7 +341,7 @@ public abstract class Literal extends Expression implements LeafExpression, Comp
public abstract LiteralExpr toLegacyLiteral();
public boolean isStringLiteral() {
public boolean isStringLikeLiteral() {
return dataType.isStringLikeType();
}
}

View File

@ -0,0 +1,49 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.literal;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.types.DataType;
/**
 * MaxLiteral, the nereids counterpart of the legacy MaxLiteral.
 * It carries only a data type and no value; presumably it stands for the maximum value of a
 * partition column (TODO confirm against the partition key handling).
 */
public class MaxLiteral extends Literal {
    // the text used for both the sql and the string representation
    private static final String DISPLAY_TEXT = "MAX_VALUE";

    public MaxLiteral(DataType dataType) {
        super(dataType);
    }

    /**
     * A max literal has no value.
     *
     * @throws AnalysisException always
     */
    @Override
    public Object getValue() {
        throw new AnalysisException("Can not get value from max literal");
    }

    /** Returns the shared legacy MAX_VALUE literal. */
    @Override
    public LiteralExpr toLegacyLiteral() {
        return org.apache.doris.analysis.MaxLiteral.MAX_VALUE;
    }

    @Override
    public String toSql() {
        return DISPLAY_TEXT;
    }

    @Override
    public String toString() {
        return DISPLAY_TEXT;
    }
}

View File

@ -30,10 +30,15 @@ public abstract class DefaultExpressionRewriter<C> extends ExpressionVisitor<Exp
@Override
public Expression visit(Expression expr, C context) {
return rewrite(this, expr, context);
}
/** rewrite */
public static final <C> Expression rewrite(ExpressionVisitor<Expression, C> rewriter, Expression expr, C context) {
List<Expression> newChildren = new ArrayList<>();
boolean hasNewChildren = false;
for (Expression child : expr.children()) {
Expression newChild = child.accept(this, context);
Expression newChild = child.accept(rewriter, context);
if (newChild != child) {
hasNewChildren = true;
}

View File

@ -35,6 +35,7 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* Utils for Nereids.
@ -255,4 +256,28 @@ public class Utils {
}
Preconditions.checkState(false, "item not found in list");
}
/**
 * allCombinations, builds every combination which takes one element from each of the given lists.
 *
 * <p>The elements of a combination keep the order of the lists, and the combinations are ordered
 * so that the element of the first list changes slowest. If no list is given, or if one of the
 * lists is empty, the result is empty.
 *
 * @param lists the lists to combine
 * @return an immutable list of immutable combinations
 */
public static <T> List<List<T>> allCombinations(List<List<T>> lists) {
    int listCount = lists.size();
    if (listCount == 0) {
        return ImmutableList.of();
    }

    // start with the last list: every element forms a combination of its own
    ImmutableList.Builder<List<T>> lastListCombinations = ImmutableList.builder();
    for (T value : lists.get(listCount - 1)) {
        lastListCombinations.add(ImmutableList.of(value));
    }
    List<List<T>> combinations = lastListCombinations.build();

    // then walk towards the first list and put each of its elements in front of every combination
    for (int index = listCount - 2; index >= 0; index--) {
        ImmutableList.Builder<List<T>> extended = ImmutableList.builder();
        for (T head : lists.get(index)) {
            for (List<T> tail : combinations) {
                List<T> combination = Stream.concat(Stream.of(head), tail.stream())
                        .collect(ImmutableList.toImmutableList());
                extended.add(combination);
            }
        }
        combinations = extended.build();
    }
    return combinations;
}
}

View File

@ -259,6 +259,8 @@ public class SessionVariable implements Serializable, Writable {
public static final String PARTITIONED_HASH_JOIN_ROWS_THRESHOLD = "partitioned_hash_join_rows_threshold";
public static final String PARTITIONED_HASH_AGG_ROWS_THRESHOLD = "partitioned_hash_agg_rows_threshold";
public static final String PARTITION_PRUNING_EXPAND_THRESHOLD = "partition_pruning_expand_threshold";
public static final String ENABLE_SHARE_HASH_TABLE_FOR_BROADCAST_JOIN
= "enable_share_hash_table_for_broadcast_join";
@ -721,6 +723,9 @@ public class SessionVariable implements Serializable, Writable {
@VariableMgr.VarAttr(name = PARTITIONED_HASH_AGG_ROWS_THRESHOLD, fuzzy = true)
public int partitionedHashAggRowsThreshold = 0;
@VariableMgr.VarAttr(name = PARTITION_PRUNING_EXPAND_THRESHOLD, fuzzy = true)
public int partitionPruningExpandThreshold = 10;
@VariableMgr.VarAttr(name = ENABLE_SHARE_HASH_TABLE_FOR_BROADCAST_JOIN, fuzzy = true)
public boolean enableShareHashTableForBroadcastJoin = true;

View File

@ -17,157 +17,370 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.analysis.IntLiteral;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.PartitionItem;
import org.apache.doris.catalog.PartitionKey;
import org.apache.doris.catalog.RangePartitionInfo;
import org.apache.doris.catalog.RangePartitionItem;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.jmockit.Deencapsulation;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.utframe.TestWithFeService;
import com.google.common.collect.BoundType;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Range;
import mockit.Expectations;
import mockit.Mocked;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
class PruneOlapScanPartitionTest extends TestWithFeService implements MemoPatternMatchSupported {
class PruneOlapScanPartitionTest implements MemoPatternMatchSupported {
@Override
protected void runBeforeAll() throws Exception {
createDatabase("test");
useDatabase("test");
@Test
void testOlapScanPartitionWithSingleColumnCase(@Mocked OlapTable olapTable) throws Exception {
List<Column> columnNameList = new ArrayList<>();
columnNameList.add(new Column("col1", Type.INT.getPrimitiveType()));
columnNameList.add(new Column("col2", Type.INT.getPrimitiveType()));
Map<Long, PartitionItem> keyItemMap = new HashMap<>();
PartitionKey k0 = new PartitionKey();
k0.pushColumn(new IntLiteral(0), Type.INT.getPrimitiveType());
PartitionKey k1 = new PartitionKey();
k1.pushColumn(new IntLiteral(5), Type.INT.getPrimitiveType());
keyItemMap.put(0L, new RangePartitionItem(Range.range(k0, BoundType.CLOSED, k1, BoundType.OPEN)));
PartitionKey k2 = new PartitionKey();
k2.pushColumn(new IntLiteral(5), Type.INT.getPrimitiveType());
PartitionKey k3 = new PartitionKey();
k3.pushColumn(new IntLiteral(10), Type.INT.getPrimitiveType());
keyItemMap.put(1L, new RangePartitionItem(Range.range(k2, BoundType.CLOSED, k3, BoundType.OPEN)));
RangePartitionInfo rangePartitionInfo = new RangePartitionInfo(columnNameList);
Deencapsulation.setField(rangePartitionInfo, "idToItem", keyItemMap);
new Expectations() {{
olapTable.getPartitionInfo();
result = rangePartitionInfo;
olapTable.getPartitionColumnNames();
result = rangePartitionInfo.getPartitionColumns().stream().map(c -> c.getName().toLowerCase())
.collect(Collectors.toSet());
olapTable.getName();
result = "tbl";
}};
LogicalOlapScan scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable);
SlotReference slotRef = new SlotReference("col1", IntegerType.INSTANCE);
Expression expression = new LessThan(slotRef, new IntegerLiteral(4));
LogicalFilter<LogicalOlapScan> filter = new LogicalFilter<>(ImmutableSet.of(expression), scan);
createTable("create table test_list_parts(id int, part int not null) "
+ "partition by list(part) ("
+ " partition p1 (('1'), ('4'), ('7')),"
+ " partition p2 (('8'), ('9'), ('5')),"
+ " partition p3 (('11'), ('0'), ('6'))"
+ ") "
+ "distributed by hash(id) "
+ "properties ('replication_num'='1')");
PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
.applyTopDown(new PruneOlapScanPartition())
.matches(
logicalFilter(
logicalOlapScan().when(
olapScan -> olapScan.getSelectedPartitionIds().iterator().next() == 0L)
)
);
createTable("create table test_range_parts(id int, part int) "
+ "partition by range(part) ("
+ " partition p1 values[('1'), ('2')),"
+ " partition p2 values[('2'), ('3')),"
+ " partition p3 values[('3'), ('4')),"
+ " partition p4 values[('4'), ('5'))"
+ ") "
+ "distributed by hash(id) "
+ "properties ('replication_num'='1')");
Expression lessThan0 = new LessThan(slotRef, new IntegerLiteral(0));
Expression greaterThan6 = new GreaterThan(slotRef, new IntegerLiteral(6));
Or lessThan0OrGreaterThan6 = new Or(lessThan0, greaterThan6);
filter = new LogicalFilter<>(ImmutableSet.of(lessThan0OrGreaterThan6), scan);
scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable);
String singleColumnPartitionTable =
"CREATE TABLE `test`.`t1` (\n"
+ " `dt` int(11) NULL COMMENT \"\",\n"
+ " `k1` int(11) NULL COMMENT \"\",\n"
+ " `k2` int(11) NULL COMMENT \"\",\n"
+ " `k3` int(11) NULL COMMENT \"\",\n"
+ " `k4` int(11) NULL COMMENT \"\"\n"
+ ") "
+ "DUPLICATE KEY(`dt`, `k1`, `k2`, `k3`, `k4`)\n"
+ "PARTITION BY RANGE(`dt`)\n"
+ "(PARTITION p20211121 VALUES LESS THAN (\"20211121\"),\n"
+ "PARTITION p20211122 VALUES [(\"20211121\"), (\"20211122\")),\n"
+ "PARTITION p20211123 VALUES [(\"20211122\"), (\"20211123\")),\n"
+ "PARTITION p20211124 VALUES [(\"20211123\"), (\"20211124\")),\n"
+ "PARTITION p20211125 VALUES [(\"20211124\"), (\"20211125\")),\n"
+ "PARTITION p20211126 VALUES [(\"20211125\"), (\"20211126\")),\n"
+ "PARTITION p20211127 VALUES [(\"20211126\"), (\"20211127\")),\n"
+ "PARTITION p20211128 VALUES [(\"20211127\"), (\"20211128\")))\n"
+ "DISTRIBUTED BY HASH(`k1`) BUCKETS 60\n"
+ "PROPERTIES('replication_num' = '1');";
PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
.applyTopDown(new PruneOlapScanPartition())
.matches(
logicalFilter(
logicalOlapScan().when(
olapScan -> olapScan.getSelectedPartitionIds().iterator().next() == 1L)
)
);
String notNullSingleColumnPartitionTable =
"CREATE TABLE `test`.`single_not_null` (\n"
+ " `dt` int(11) NULL COMMENT \"\",\n"
+ " `k1` int(11) NULL COMMENT \"\",\n"
+ " `k2` int(11) NULL COMMENT \"\",\n"
+ " `k3` int(11) NULL COMMENT \"\",\n"
+ " `k4` int(11) NULL COMMENT \"\"\n"
+ ") "
+ "DUPLICATE KEY(`dt`, `k1`, `k2`, `k3`, `k4`)\n"
+ "PARTITION BY RANGE(`dt`)\n"
+ "(PARTITION p20211122 VALUES [(\"20211121\"), (\"20211122\")),\n"
+ "PARTITION p20211123 VALUES [(\"20211122\"), (\"20211123\")),\n"
+ "PARTITION p20211124 VALUES [(\"20211123\"), (\"20211124\")),\n"
+ "PARTITION p20211125 VALUES [(\"20211124\"), (\"20211125\")),\n"
+ "PARTITION p20211126 VALUES [(\"20211125\"), (\"20211126\")),\n"
+ "PARTITION p20211127 VALUES [(\"20211126\"), (\"20211127\")),\n"
+ "PARTITION p20211128 VALUES [(\"20211127\"), (\"20211128\")))\n"
+ "DISTRIBUTED BY HASH(`k1`) BUCKETS 60\n"
+ "PROPERTIES('replication_num' = '1');";
Expression greaterThanEqual0 =
new GreaterThanEqual(
slotRef, new IntegerLiteral(0));
Expression lessThanEqual5 =
new LessThanEqual(slotRef, new IntegerLiteral(5));
scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable);
filter = new LogicalFilter<>(ImmutableSet.of(greaterThanEqual0, lessThanEqual5), scan);
String multipleColumnsPartitionTable =
"CREATE TABLE `test`.`t2` (\n"
+ " `k1` int(11) NULL COMMENT \"\",\n"
+ " `k2` int(11) NULL COMMENT \"\",\n"
+ " `k3` int(11) NULL COMMENT \"\",\n"
+ " `k4` int(11) NULL COMMENT \"\",\n"
+ " `k5` int(11) NULL COMMENT \"\"\n"
+ ") \n"
+ "PARTITION BY RANGE(`k1`, `k2`)\n"
+ "(PARTITION p1 VALUES LESS THAN (\"3\", \"1\"),\n"
+ "PARTITION p2 VALUES [(\"3\", \"1\"), (\"7\", \"10\")),\n"
+ "PARTITION p3 VALUES [(\"7\", \"10\"), (\"8\", \"5\")),\n"
+ "PARTITION p4 VALUES [(\"10\", \"10\"), (\"12\", \"5\")),\n"
+ "PARTITION p5 VALUES [(\"15\", \"6\"), (\"20\", \"11\")),\n"
+ "PARTITION p6 VALUES [(\"20\", \"11\"), (\"22\", \"3\")),\n"
+ "PARTITION p7 VALUES [(\"23\", \"3\"), (\"23\", \"4\")),\n"
+ "PARTITION p8 VALUES [(\"23\", \"4\"), (\"23\", \"20\")),\n"
+ "PARTITION p9 VALUES [(\"24\", \"1\"), (\"25\", \"9\")))\n"
+ "DISTRIBUTED BY HASH(`k1`) BUCKETS 10\n"
+ "PROPERTIES ('replication_num' = '1');";
PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
.applyTopDown(new PruneOlapScanPartition())
.matches(
logicalFilter(
logicalOlapScan().when(
olapScan -> olapScan.getSelectedPartitionIds().iterator().next() == 0L)
.when(olapScan -> olapScan.getSelectedPartitionIds().size() == 2)
)
);
String notNullMultipleColumnsPartitionTable =
"CREATE TABLE `test`.`multi_not_null` (\n"
+ " `k1` int(11) NULL COMMENT \"\",\n"
+ " `k2` int(11) NULL COMMENT \"\",\n"
+ " `k3` int(11) NULL COMMENT \"\",\n"
+ " `k4` int(11) NULL COMMENT \"\",\n"
+ " `k5` int(11) NULL COMMENT \"\"\n"
+ ") \n"
+ "PARTITION BY RANGE(`k1`, `k2`)\n"
+ "(PARTITION p1 VALUES [(\"3\", \"1\"), (\"3\", \"3\")),\n"
+ "PARTITION p2 VALUES [(\"4\", \"2\"), (\"4\", \"6\")))\n"
+ "DISTRIBUTED BY HASH(`k1`) BUCKETS 10\n"
+ "PROPERTIES ('replication_num' = '1');";
createTables(singleColumnPartitionTable,
notNullSingleColumnPartitionTable,
multipleColumnsPartitionTable,
notNullMultipleColumnsPartitionTable);
}
@Test
void testOlapScanPartitionPruneWithMultiColumnCase(@Mocked OlapTable olapTable) throws Exception {
List<Column> columnNameList = new ArrayList<>();
columnNameList.add(new Column("col1", Type.INT.getPrimitiveType()));
columnNameList.add(new Column("col2", Type.INT.getPrimitiveType()));
Map<Long, PartitionItem> keyItemMap = new HashMap<>();
PartitionKey k0 = new PartitionKey();
k0.pushColumn(new IntLiteral(1), Type.INT.getPrimitiveType());
k0.pushColumn(new IntLiteral(10), Type.INT.getPrimitiveType());
PartitionKey k1 = new PartitionKey();
k1.pushColumn(new IntLiteral(4), Type.INT.getPrimitiveType());
k1.pushColumn(new IntLiteral(5), Type.INT.getPrimitiveType());
keyItemMap.put(0L, new RangePartitionItem(Range.range(k0, BoundType.CLOSED, k1, BoundType.OPEN)));
RangePartitionInfo rangePartitionInfo = new RangePartitionInfo(columnNameList);
Deencapsulation.setField(rangePartitionInfo, "idToItem", keyItemMap);
new Expectations() {{
olapTable.getPartitionInfo();
result = rangePartitionInfo;
olapTable.getPartitionColumnNames();
result = rangePartitionInfo.getPartitionColumns().stream().map(c -> c.getName().toLowerCase())
.collect(Collectors.toSet());
olapTable.getName();
result = "tbl";
}};
LogicalOlapScan scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable);
Expression left = new LessThan(new SlotReference("col1", IntegerType.INSTANCE), new IntegerLiteral(4));
Expression right = new GreaterThan(new SlotReference("col2", IntegerType.INSTANCE), new IntegerLiteral(11));
LogicalFilter<LogicalOlapScan> filter = new LogicalFilter<>(ImmutableSet.of(left, right), scan);
PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
.applyTopDown(new PruneOlapScanPartition())
.matches(
logicalFilter(
logicalOlapScan()
.when(
olapScan -> olapScan.getSelectedPartitionIds().iterator().next() == 0L)
)
);
void testOlapScanPartitionWithSingleColumnCase() throws Exception {
    // Range-partitioned table on a single int column with two adjacent
    // partitions: p1 = [0, 5) and p2 = [5, 10).
    String table = "testOlapScanPartitionWithSingleColumnCase";
    String ddl = "create table " + table + "("
            + "  id int not null,"
            + "  col1 int not null"
            + "  ) "
            + "partition by range(col1) ("
            + "  partition p1 values[('0'), ('5')),"
            + "  partition p2 values[('5'), ('10'))"
            + ") "
            + "distributed by hash(id) "
            + "properties ('replication_num'='1')";
    createTable(ddl);

    // Each check states how many partitions are expected to stay selected.
    test(table, "col1 < 4", 1);
    test(table, "col1 < 0 or col1 > 6", 1);
    test(table, "col1 >= 0 and col1 <= 5", 2);
}
/**
 * Pruning checks against a table with one two-column range partition
 * {@code [(1, 10), (4, 5))}. Every check passes a predicate and the number
 * of partitions expected to remain selected (0 or 1 here).
 */
@Test
void testOlapScanPartitionPruneWithMultiColumnCase() throws Exception {
    String table = "testOlapScanPartitionPruneWithMultiColumnCase";
    String ddl = "create table " + table + "("
            + "  id int not null,"
            + "  col1 int not null,"
            + "  col2 int not null"
            + "  ) "
            + "partition by range(col1, col2) ("
            + "  partition p1 values[('1', '10'), ('4', '5'))"
            + ") "
            + "distributed by hash(id) "
            + "properties ('replication_num'='1')";
    createTable(ddl);

    // predicates on the first partition column only
    test(table, "col1 = 4", 1);
    test(table, "col1 = 1", 1);
    test(table, "col1 = 2", 1);
    test(table, "col1 < 1", 0);
    test(table, "col1 >= 4", 1);
    test(table, "col1 > 4", 0);

    // predicates on the second partition column only
    test(table, "col2 = 10", 1);
    test(table, "col2 = 5", 1);
    test(table, "col2 = 100", 1);
    test(table, "col2 = -1", 1);
    test(table, "col2 < 10", 1);
    test(table, "col2 >= 5", 1);
    test(table, "col2 > 5", 1);

    // conjunctions over both partition columns
    test(table, "col1 < 4 and col2 > 11", 1);
    test(table, "col1 < 2 and col2 > 11", 1);
    test(table, "col1 < 2 and col2 <= 10", 1);
    test(table, "col1 < 2 and col2 < 10", 0);
    test(table, "col1 < 2 and col2 <= 10", 1);
    test(table, "col1 >= 4 and col2 > 10", 0);

    // disjunctions over both partition columns
    test(table, "col1 < 4 or col2 > 5", 1);
    test(table, "col1 < 4 or col2 > 3", 1);
    test(table, "col1 < 1 or col2 > 5", 1);
    test(table, "col1 <= 1 or col2 >= 10", 1);

    // predicates that wrap the partition column in a cast / arithmetic
    test(table, "cast(col1 as bigint) = 1", 1);
    test(table, "cast(col1 as bigint) = 5", 0);
    test(table, "cast(col1 as bigint) + 1 = 5", 1);
}
/**
 * Checks pruning when the partition column appears in both branches of an
 * OR, one of which also constrains the non-partition column {@code id}.
 */
@Test
public void prunePartitionWithOrPredicate() {
    String listPredicate = "(part = 9 and id <= 500) or (part = 3)";
    String rangePredicate = "(part = 1 and id <= 500) or (part = 3)";

    test("test_list_parts", listPredicate, 1);
    test("test_range_parts", rangePredicate, 2);
}
/**
 * Checks predicates that mix the partition column with {@code id} in an
 * arithmetic expression: as an OR branch all four partitions are expected
 * to be scanned, as an AND branch the other conjunct narrows it to one.
 */
@Test
public void canNotPruneComplexPredicate() {
    String table = "test_range_parts";

    test(table, "(part = 10) or (part + id = 1)", 4);
    test(table, "(part + id = 1) and (part = 4)", 1);
}
/**
 * Creates a list-partitioned table keyed on (part1, part2) and checks that
 * a predicate matching none of the listed value pairs selects no partition.
 */
@Test
public void pruneMultiColumnListPartition() throws Exception {
    String table = "test_multi_list_parts";
    String ddl = "create table " + table + "(id int, part1 int not null, part2 varchar(32) not null) "
            + "partition by list(part1, part2) ("
            + "  partition p1 (('1', 'd'), ('3', 'a')),"
            + "  partition p2 (('4', 'c'), ('6', 'f'))"
            + ") "
            + "distributed by hash(id) "
            + "properties ('replication_num'='1')";
    createTable(ddl);

    test(table, "part1 = 1 and part2 < 'd'", 0);
}
/**
 * Checks pruning through {@code date(col1)} on a datetime range partition
 * that spans a single day: the matching date keeps the partition, the next
 * day's date is expected to select nothing.
 */
@Test
void testOlapScanPartitionPruneWithNonEnumerableRange() throws Exception {
    String table = "testOlapScanPartitionPruneWithNonEnumerableRange";
    String ddl = "create table " + table + "("
            + "  id int not null,"
            + "  col1 datetime not null"
            + "  ) "
            + "partition by range(col1) ("
            + "  partition p1 values[('2023-03-21 00:00:00'), ('2023-03-21 23:59:59'))"
            + ") "
            + "distributed by hash(id) "
            + "properties ('replication_num'='1')";
    createTable(ddl);

    test(table, "date(col1) = '2023-03-21'", 1);
    test(table, "date(col1) = '2023-03-22'", 0);
}
/**
 * Creates an aggregate table whose last range partition is bounded by
 * MAXVALUE, then runs LIKE predicates over literals and expects all four
 * partitions to stay selected in every case.
 */
@Test
void testMaxValue() throws Exception {
    String ddl = "CREATE TABLE IF NOT EXISTS `test_basic_agg` (\n"
            + "  `k1` tinyint(4) NULL COMMENT \"\",\n"
            + "  `k2` smallint(6) NULL COMMENT \"\",\n"
            + "  `k3` int(11) NULL COMMENT \"\",\n"
            + "  `k4` bigint(20) NULL COMMENT \"\",\n"
            + "  `k5` decimal(9, 3) NULL COMMENT \"\",\n"
            + "  `k6` char(5) NULL COMMENT \"\",\n"
            + "  `k10` date NULL COMMENT \"\",\n"
            + "  `k11` datetime NULL COMMENT \"\",\n"
            + "  `k7` varchar(20) NULL COMMENT \"\",\n"
            + "  `k8` double MAX NULL COMMENT \"\",\n"
            + "  `k9` float SUM NULL COMMENT \"\"\n"
            + ") ENGINE=OLAP\n"
            + "AGGREGATE KEY(`k1`, `k2`, `k3`, `k4`, `k5`, `k6`, `k10`, `k11`, `k7`)\n"
            + "COMMENT \"OLAP\"\n"
            + "PARTITION BY RANGE(`k1`)\n"
            + "(PARTITION p1 VALUES [(\"-128\"), (\"-64\")),\n"
            + "PARTITION p2 VALUES [(\"-64\"), (\"0\")),\n"
            + "PARTITION p3 VALUES [(\"0\"), (\"64\")),\n"
            + "PARTITION p4 VALUES [(\"64\"), (MAXVALUE)))\n"
            + "DISTRIBUTED BY HASH(`k1`) BUCKETS 5\n"
            + "PROPERTIES (\n"
            + "\"replication_allocation\" = \"tag.location.default: 1\",\n"
            + "\"in_memory\" = \"false\",\n"
            + "\"storage_format\" = \"V2\"\n"
            + ");";
    createTable(ddl);

    // TODO: support like function to prune partition
    // Until then every predicate below is expected to keep all 4 partitions.
    String[] likePredicates = {
        " 1998 like '1%'",
        " '1998' like '1%'",
        " 2998 like '1%'",
        " '2998' like '1%'",
        " 199.8 like '1%'",
        "'199.8' like '1%'",
        " 299.8 like '1%'",
        "'299.8' like '1%'",
    };
    for (String predicate : likePredicates) {
        test("test_basic_agg", predicate, 4);
    }
}
/**
 * Cases carried over from the legacy partition pruner tests, run against the
 * tables t1 / single_not_null (one partition column) and t2 / multi_not_null
 * (two partition columns). A trailing "legacy returned N" note records the
 * count the old pruner produced where the expectation here differs.
 */
@Test
void legacyTests() {
    String single = "t1";
    String multi = "t2";

    // ---- 1. Single partition column ----
    // no filter at all
    test(single, "", 8);
    // =
    test(single, "dt=20211122", 1);
    // <
    test(single, "dt<20211122", 2);
    // <=
    test(single, "dt<=20211122", 3);
    // >
    test(single, "dt>20211122", 5); // legacy returned 6
    // >=
    test(single, "dt>=20211122", 6);
    // in-list
    test(single, "dt in (20211124, 20211126, 20211122)", 3);
    // is null
    test(single, "dt is null", 1);
    test("`single_not_null`", "dt is null", 0);
    // !=
    test(single, "dt!=20211122", 7); // legacy returned 8

    // ---- 2. Multiple partition columns ----
    // no filter at all
    test(multi, "", 9);
    // =
    test(multi, "k1=7", 2);
    test(multi, "k2=7", 7); // legacy returned 9
    // <
    test(multi, "k1<7", 2);
    test(multi, "k2<7", 9);
    // <= on the first column, > on the second column
    test(multi, "k1<=7", 3);
    test(multi, "k2>7", 8); // legacy returned 9
    // >=
    test(multi, "k1>=7", 8);
    test(multi, "k2>=7", 8); // legacy returned 9
    // in-list
    test(multi, "k1 in (7,9,16)", 3);
    test(multi, "k2 in (7,9,16)", 8); // legacy returned 9
    // is null
    test(multi, "k1 is null", 1);
    test(multi, "k2 is null", 7); // legacy returned 9
    test("multi_not_null", "k1 is null", 0);
    test("multi_not_null", "k2 is null", 0); // legacy returned 2
    // !=
    test(multi, "k1!=23", 7); // legacy returned 9
    test(multi, "k2!=23", 9);

    // ---- 3. Conjunctive predicates ----
    // equality combined with another predicate
    test(multi, "k1=23 and k2=5", 1);
    test(multi, "k1=23 and k2>5", 1);
    // in-list combined with another predicate
    test(multi, "k1 in (3, 10, 13) and k2>10", 2);
    // is null combined with another predicate
    test(multi, "k1 > 10 and k1 is null", 0);
    test(multi, "k1 is null and k1 > 10", 0);
    test("multi_not_null", "k1 > 10 and k1 is null", 0);
    // other combinations
    test(multi, "k1 > 10 and k2 < 4", 5); // legacy returned 6
    test(multi, "k1 >10 and k1 < 10 and (k1=11 or k1=12)", 0);
    test(multi, "k1 > 20 and k1 < 7 and k1 = 10", 0);

    // ---- 4. Disjunctive predicates ----
    test(multi, "k1=10 or k1=23", 3);
    test(multi, "(k1=10 or k1=23) and (k2=4 or k2=5)", 1);
    test(multi, "(k1=10 or k1=23) and (k2=4 or k2=11)", 2);
    test(multi, "(k1=10 or k1=23) and (k2=3 or k2=4 or k2=11)", 3);
    test(single, "dt=20211123 or dt=20211124", 2);
    test(single, "((dt=20211123 and k1=1) or (dt=20211125 and k1=3))", 2);
    // maybe something goes wrong with ExtractCommonFactorsRule.
    test(single, "((dt=20211123 and k1=1) or (dt=20211125 and k1=3)) and k2>0", 2);
    test(multi, "k1 > 10 or k2 < 1", 9);

    // extra coverage for compound predicates
    test(single, "(dt >= 20211121 and dt <= 20211122) or (dt >= 20211123 and dt <= 20211125)", 5);
    test(single, "(dt between 20211121 and 20211122) or (dt between 20211123 and 20211125)", 5);
    test(single, "(dt between 20211121 and 20211122) or dt is null or (dt between 20211123 and 20211125)", 6);
}
/**
 * Analyzes and rewrites {@code select * from <table> [where <filter>]},
 * prints the resulting plan tree, and asserts how many partitions the olap
 * scan has selected.
 *
 * @param table table name as it should appear in the FROM clause
 * @param filter predicate text for the WHERE clause; an empty string means no WHERE clause
 * @param expectScanPartitionNum expected size of the scan's selected partition ids
 */
private void test(String table, String filter, int expectScanPartitionNum) {
    String sql = "select * from " + table;
    if (!filter.isEmpty()) {
        sql = sql + " where " + filter;
    }
    PlanChecker checker = PlanChecker.from(connectContext)
            .analyze(sql)
            .rewrite()
            .printlnTree();

    if (expectScanPartitionNum == 0) {
        // With zero expected partitions the plan may already be an empty
        // relation; accept that shape first. If that match fails for any
        // reason, fall back to checking the scan itself below.
        boolean matchedEmptyRelation;
        try {
            checker.matches(logicalEmptyRelation());
            matchedEmptyRelation = true;
        } catch (Throwable ignored) {
            matchedEmptyRelation = false;
        }
        if (matchedEmptyRelation) {
            return;
        }
    }

    checker.matches(
            logicalOlapScan().when(olapScan -> {
                int selected = olapScan.getSelectedPartitionIds().size();
                // Assert inside the predicate so a mismatch reports expected vs. actual.
                Assertions.assertEquals(expectScanPartitionNum, selected);
                return true;
            })
    );
}
}

View File

@ -154,7 +154,7 @@ suite("test_aggregate_collect") {
${tableName}
"""
qt_select """
order_qt_select """
SELECT
size(collect_set(c_bool,1)),
size(collect_set(c_tinyint,1)),