diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownExpressionsInHashCondition.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownExpressionsInHashCondition.java index 6f1d6e68b0..bc92d8f37c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownExpressionsInHashCondition.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownExpressionsInHashCondition.java @@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; @@ -33,7 +34,6 @@ import org.apache.doris.nereids.util.JoinUtils; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; @@ -42,6 +42,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; /** * push down expression which is not slot reference @@ -70,7 +71,7 @@ public class PushdownExpressionsInHashCondition extends OneRewriteRuleFactory { .then(join -> { List> exprsOfHashConjuncts = Lists.newArrayList(Lists.newArrayList(), Lists.newArrayList()); - Map exprMap = Maps.newHashMap(); + Map exprMap = Maps.newHashMap(); join.getHashJoinConjuncts().forEach(conjunct -> { Preconditions.checkArgument(conjunct instanceof EqualTo); // sometimes: t1 join t2 on t2.a + 1 = t1.a + 2, so check the situation, but actually it @@ -79,8 +80,13 @@ public class PushdownExpressionsInHashCondition extends OneRewriteRuleFactory { (EqualTo) conjunct, join.left().getOutputSet()); exprsOfHashConjuncts.get(0).add(conjunct.child(0)); exprsOfHashConjuncts.get(1).add(conjunct.child(1)); - conjunct.children().forEach(expr -> - exprMap.put(expr, new Alias(expr, "expr_" + expr.toSql()))); + conjunct.children().forEach(expr -> { + if ((expr instanceof SlotReference)) { + exprMap.put(expr, (SlotReference) expr); + } else { + exprMap.put(expr, new Alias(expr, "expr_" + expr.toSql())); + } + }); }); Iterator> iter = exprsOfHashConjuncts.iterator(); return join.withHashJoinConjunctsAndChildren( @@ -90,10 +96,16 @@ public class PushdownExpressionsInHashCondition extends OneRewriteRuleFactory { .collect(ImmutableList.toImmutableList()))) .collect(ImmutableList.toImmutableList()), join.children().stream().map( - plan -> new LogicalProject<>(new Builder() - .addAll(iter.next().stream().map(exprMap::get) - .collect(ImmutableList.toImmutableList())) - .addAll(getOutput(plan, join)).build(), plan)) + plan -> { + Set projectSet = Sets.newHashSet(); + projectSet.addAll(iter.next().stream().map(exprMap::get) + .collect(Collectors.toList())); + projectSet.addAll(getOutput(plan, join)); + List projectList = projectSet.stream() + .collect(ImmutableList.toImmutableList()); + return new LogicalProject<>(projectList, plan); + } + ) .collect(ImmutableList.toImmutableList())); }).toRule(RuleType.PUSHDOWN_EXPRESSIONS_IN_HASH_CONDITIONS); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java index 1d9e223890..bcb9e83a68 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java @@ -666,6 +666,7 @@ public class StatsCalculator extends DefaultPlanVisitor } return Pair.of(expr.toSlot().getExprId(), value); }).collect(Collectors.toMap(Pair::key, Pair::value)); + columnStatisticMap.putAll(childColumnStats); return new StatsDeriveResult(stats.getRowCount(), columnStatisticMap); } }