[fix](Nereids) not equals and hashCode should contains generate flag (#31123)

This commit is contained in:
morrySnow
2024-02-22 19:03:25 +08:00
committed by yiguolei
parent f1a7f9a70f
commit 22efb1cb7a
12 changed files with 119 additions and 54 deletions

View File

@ -37,6 +37,7 @@ import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.scalar.NonNullable;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Nullable;
@ -500,9 +501,16 @@ public abstract class AbstractMaterializedViewRule implements ExplorationRuleFac
CascadesContext cascadesContext) {
Set<Expression> queryPulledUpPredicates = queryPredicates.stream()
.flatMap(expr -> ExpressionUtils.extractConjunction(expr).stream())
.map(expr -> {
// NOTICE inferNotNull generate Not with isGeneratedIsNotNull = false,
// so, we need set this flag to false before comparison.
if (expr instanceof Not) {
return ((Not) expr).withGeneratedIsNotNull(false);
}
return expr;
})
.collect(Collectors.toSet());
Set<Expression> nullRejectPredicates = ExpressionUtils.inferNotNull(queryPulledUpPredicates,
cascadesContext);
Set<Expression> nullRejectPredicates = ExpressionUtils.inferNotNull(queryPulledUpPredicates, cascadesContext);
Set<Expression> queryUsedNeedRejectNullSlotsViewBased = nullRejectPredicates.stream()
.map(expression -> TypeUtils.isNotNull(expression).orElse(null))
.filter(Objects::nonNull)

View File

@ -53,19 +53,23 @@ public class EliminateNotNull implements RewriteRuleFactory {
public List<Rule> buildRules() {
return ImmutableList.of(
logicalFilter()
.when(filter -> filter.getConjuncts().stream().anyMatch(expr -> expr.isGeneratedIsNotNull))
.thenApply(ctx -> {
LogicalFilter<Plan> filter = ctx.root;
List<Expression> predicates = removeGeneratedNotNull(filter.getConjuncts(),
ctx.cascadesContext);
if (predicates.size() == filter.getConjuncts().size()) {
return null;
}
return PlanUtils.filterOrSelf(ImmutableSet.copyOf(predicates), filter.child());
}).toRule(RuleType.ELIMINATE_NOT_NULL),
innerLogicalJoin()
.when(join -> join.getOtherJoinConjuncts().stream().anyMatch(expr -> expr.isGeneratedIsNotNull))
.thenApply(ctx -> {
LogicalJoin<Plan, Plan> join = ctx.root;
List<Expression> newOtherJoinConjuncts = removeGeneratedNotNull(
join.getOtherJoinConjuncts(), ctx.cascadesContext);
if (newOtherJoinConjuncts.size() == join.getOtherJoinConjuncts().size()) {
return null;
}
return join.withJoinConjuncts(join.getHashJoinConjuncts(), newOtherJoinConjuncts);
})
.toRule(RuleType.ELIMINATE_NOT_NULL)
@ -81,7 +85,8 @@ public class EliminateNotNull implements RewriteRuleFactory {
Set<Expression> predicatesNotContainIsNotNull = Sets.newHashSet();
List<Slot> slotsFromIsNotNull = Lists.newArrayList();
exprs.stream()
.filter(expr -> !expr.isGeneratedIsNotNull) // remove generated `is not null`
.filter(expr -> !(expr instanceof Not)
|| !((Not) expr).isGeneratedIsNotNull()) // remove generated `is not null`
.forEach(expr -> {
Optional<Slot> notNullSlot = TypeUtils.isNotNull(expr);
if (notNullSlot.isPresent()) {

View File

@ -75,8 +75,7 @@ public class EliminateOuterJoin extends OneRewriteRuleFactory {
boolean conjunctsChanged = false;
if (!notNullSlots.isEmpty()) {
for (Slot slot : notNullSlots) {
Not isNotNull = new Not(new IsNull(slot));
isNotNull.isGeneratedIsNotNull = true;
Not isNotNull = new Not(new IsNull(slot), true);
conjunctsChanged |= conjuncts.add(isNotNull);
}
}
@ -135,13 +134,11 @@ public class EliminateOuterJoin extends OneRewriteRuleFactory {
private boolean createIsNotNullIfNecessary(EqualPredicate swapedEqualTo, Collection<Expression> container) {
boolean containerChanged = false;
if (swapedEqualTo.left().nullable()) {
Not not = new Not(new IsNull(swapedEqualTo.left()));
not.isGeneratedIsNotNull = true;
Not not = new Not(new IsNull(swapedEqualTo.left()), true);
containerChanged |= container.add(not);
}
if (swapedEqualTo.right().nullable()) {
Not not = new Not(new IsNull(swapedEqualTo.right()));
not.isGeneratedIsNotNull = true;
Not not = new Not(new IsNull(swapedEqualTo.right()), true);
containerChanged |= container.add(not);
}
return containerChanged;

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
@ -32,6 +33,9 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
import com.google.common.collect.ImmutableSet;
import java.util.Collections;
import java.util.Set;
import java.util.stream.Collectors;
@ -55,12 +59,25 @@ public class InferAggNotNull extends OneRewriteRuleFactory {
LogicalAggregate<Plan> agg = ctx.root;
Set<Expression> exprs = agg.getAggregateFunctions().stream().flatMap(f -> f.children().stream())
.collect(Collectors.toSet());
Set<Expression> isNotNull = ExpressionUtils.inferNotNull(exprs, ctx.cascadesContext);
if (isNotNull.size() == 0 || (agg.child() instanceof Filter && isNotNull.equals(
((Filter) agg.child()).getConjuncts()))) {
Set<Expression> isNotNulls = ExpressionUtils.inferNotNull(exprs, ctx.cascadesContext);
Set<Expression> predicates = Collections.emptySet();
if ((agg.child() instanceof Filter)) {
predicates = ((Filter) agg.child()).getConjuncts();
}
ImmutableSet.Builder<Expression> needGenerateNotNullsBuilder = ImmutableSet.builder();
for (Expression isNotNull : isNotNulls) {
if (!predicates.contains(isNotNull)) {
isNotNull = ((Not) isNotNull).withGeneratedIsNotNull(true);
if (!predicates.contains(isNotNull)) {
needGenerateNotNullsBuilder.add(isNotNull);
}
}
}
Set<Expression> needGenerateNotNulls = needGenerateNotNullsBuilder.build();
if (needGenerateNotNulls.isEmpty()) {
return null;
}
return agg.withChildren(PlanUtils.filter(isNotNull, agg.child()).get());
return agg.withChildren(PlanUtils.filter(needGenerateNotNulls, agg.child()).get());
}).toRule(RuleType.INFER_AGG_NOT_NULL);
}
}

View File

@ -20,13 +20,14 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSet.Builder;
import com.google.common.collect.Streams;
import java.util.Set;
@ -43,18 +44,27 @@ public class InferFilterNotNull extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalFilter()
.when(filter -> filter.getConjuncts().stream().noneMatch(expr -> expr.isGeneratedIsNotNull))
.when(filter -> filter.getConjuncts().stream()
.filter(Not.class::isInstance)
.map(Not.class::cast)
.noneMatch(Not::isGeneratedIsNotNull))
.thenApply(ctx -> {
LogicalFilter<Plan> filter = ctx.root;
Set<Expression> predicates = filter.getConjuncts();
Set<Expression> isNotNull = ExpressionUtils.inferNotNull(predicates, ctx.cascadesContext);
if (isNotNull.isEmpty() || predicates.containsAll(isNotNull)) {
Set<Expression> isNotNulls = ExpressionUtils.inferNotNull(predicates, ctx.cascadesContext);
ImmutableSet.Builder<Expression> needGenerateNotNullsBuilder = ImmutableSet.builder();
for (Expression isNotNull : isNotNulls) {
if (!predicates.contains(isNotNull)) {
needGenerateNotNullsBuilder.add(((Not) isNotNull).withGeneratedIsNotNull(true));
}
}
Set<Expression> needGenerateNotNulls = needGenerateNotNullsBuilder.build();
if (needGenerateNotNulls.isEmpty()) {
return null;
}
Builder<Expression> builder = ImmutableSet.<Expression>builder()
.addAll(predicates)
.addAll(isNotNull);
return PlanUtils.filter(builder.build(), filter.child()).get();
Set<Expression> conjuncts = Streams.concat(predicates.stream(), needGenerateNotNulls.stream())
.collect(ImmutableSet.toImmutableSet());
return PlanUtils.filter(conjuncts, filter.child()).get();
}).toRule(RuleType.INFER_FILTER_NOT_NULL);
}
}

View File

@ -59,7 +59,6 @@ import java.util.Set;
public abstract class Expression extends AbstractTreeNode<Expression> implements ExpressionTrait {
public static final String DEFAULT_EXPRESSION_NAME = "expression";
// Mask this expression is generated by rule, should be removed.
public boolean isGeneratedIsNotNull = false;
protected Optional<String> exprName = Optional.empty();
private final int depth;
private final int width;

View File

@ -38,12 +38,24 @@ public class Not extends Expression implements UnaryExpression, ExpectsInputType
public static final List<DataType> EXPECTS_INPUT_TYPES = ImmutableList.of(BooleanType.INSTANCE);
private final boolean isGeneratedIsNotNull;
public Not(Expression child) {
super(ImmutableList.of(child));
this(child, false);
}
private Not(List<Expression> child) {
public Not(Expression child, boolean isGeneratedIsNotNull) {
super(ImmutableList.of(child));
this.isGeneratedIsNotNull = isGeneratedIsNotNull;
}
private Not(List<Expression> child, boolean isGeneratedIsNotNull) {
super(child);
this.isGeneratedIsNotNull = isGeneratedIsNotNull;
}
public boolean isGeneratedIsNotNull() {
return isGeneratedIsNotNull;
}
@Override
@ -70,12 +82,13 @@ public class Not extends Expression implements UnaryExpression, ExpectsInputType
return false;
}
Not other = (Not) o;
return Objects.equals(child(), other.child());
return Objects.equals(child(), other.child())
&& isGeneratedIsNotNull == other.isGeneratedIsNotNull;
}
@Override
public int hashCode() {
return child().hashCode();
return Objects.hash(child().hashCode(), isGeneratedIsNotNull);
}
@Override
@ -91,9 +104,11 @@ public class Not extends Expression implements UnaryExpression, ExpectsInputType
@Override
public Not withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
Not not = new Not(children);
not.isGeneratedIsNotNull = this.isGeneratedIsNotNull;
return not;
return new Not(children, isGeneratedIsNotNull);
}
public Not withGeneratedIsNotNull(boolean isGeneratedIsNotNull) {
return new Not(children, isGeneratedIsNotNull);
}
@Override

View File

@ -19,7 +19,9 @@ package org.apache.doris.nereids.trees.plans.commands;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.Predicate;
import org.apache.doris.analysis.SetVar;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.StringLiteral;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.Database;
import org.apache.doris.catalog.Env;
@ -32,6 +34,7 @@ import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
@ -54,9 +57,11 @@ import org.apache.doris.nereids.util.Utils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.qe.StmtExecutor;
import org.apache.doris.qe.VariableMgr;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.apache.commons.lang3.StringUtils;
import java.util.List;
import java.util.Optional;
@ -90,7 +95,7 @@ public class DeleteFromCommand extends Command implements ForwardWithSync {
@Override
public void run(ConnectContext ctx, StmtExecutor executor) throws Exception {
LogicalPlanAdapter logicalPlanAdapter = new LogicalPlanAdapter(logicalQuery, ctx.getStatementContext());
turnOffForbidUnknownStats(ctx.getSessionVariable());
updateSessionVariableForDelete(ctx.getSessionVariable());
NereidsPlanner planner = new NereidsPlanner(ctx.getStatementContext());
planner.plan(logicalPlanAdapter, ctx.getSessionVariable().toThrift());
executor.setPlanner(planner);
@ -173,9 +178,22 @@ public class DeleteFromCommand extends Command implements ForwardWithSync {
Lists.newArrayList(relation.getPartNames()), predicates, ctx.getState());
}
private void turnOffForbidUnknownStats(SessionVariable sessionVariable) {
private void updateSessionVariableForDelete(SessionVariable sessionVariable) {
sessionVariable.setIsSingleSetVar(true);
sessionVariable.setForbidUnownColStats(false);
try {
// turn off forbid unknown col stats
VariableMgr.setVar(sessionVariable,
new SetVar(SessionVariable.FORBID_UNKNOWN_COLUMN_STATS, new StringLiteral("false")));
// disable eliminate not null rule
List<String> disableRules = Lists.newArrayList(
RuleType.ELIMINATE_NOT_NULL.name(), RuleType.INFER_FILTER_NOT_NULL.name());
disableRules.addAll(sessionVariable.getDisableNereidsRuleNames());
VariableMgr.setVar(sessionVariable,
new SetVar(SessionVariable.DISABLE_NEREIDS_RULES,
new StringLiteral(StringUtils.join(disableRules, ","))));
} catch (Exception e) {
throw new AnalysisException("set session variable by delete from command failed", e);
}
}
private void checkColumn(Set<String> tableColumns, SlotReference slotReference, OlapTable table) {

View File

@ -574,11 +574,8 @@ public class ExpressionUtils {
*/
public static Set<Expression> inferNotNull(Set<Expression> predicates, CascadesContext cascadesContext) {
return inferNotNullSlots(predicates, cascadesContext).stream()
.map(slot -> {
Not isNotNull = new Not(new IsNull(slot));
isNotNull.isGeneratedIsNotNull = true;
return isNotNull;
}).collect(Collectors.toSet());
.map(slot -> new Not(new IsNull(slot), false))
.collect(Collectors.toSet());
}
/**
@ -589,11 +586,8 @@ public class ExpressionUtils {
CascadesContext cascadesContext) {
return inferNotNullSlots(predicates, cascadesContext).stream()
.filter(slots::contains)
.map(slot -> {
Not isNotNull = new Not(new IsNull(slot));
isNotNull.isGeneratedIsNotNull = true;
return isNotNull;
}).collect(Collectors.toSet());
.map(slot -> new Not(new IsNull(slot), true))
.collect(Collectors.toSet());
}
public static <E extends Expression> List<E> flatExpressions(List<List<E>> expressions) {

View File

@ -2391,7 +2391,7 @@ public class SessionVariable implements Serializable, Writable {
return forbidUnknownColStats;
}
public void setForbidUnownColStats(boolean forbid) {
public void setForbidUnknownColStats(boolean forbid) {
forbidUnknownColStats = forbid;
}