[feature](nereids) support subquery in select list (#23271)
1. add scalar subquery's output to LogicalApply's output 2. for in and exists subquery's, add mark join slot into LogicalApply's output 3. forbid push down alias through join if the project list have any mark join slots. 4. move normalize aggregate rule to analysis phase
This commit is contained in:
@ -30,7 +30,9 @@ import org.apache.doris.nereids.rules.analysis.BindSink;
|
||||
import org.apache.doris.nereids.rules.analysis.CheckAnalysis;
|
||||
import org.apache.doris.nereids.rules.analysis.CheckBound;
|
||||
import org.apache.doris.nereids.rules.analysis.CheckPolicy;
|
||||
import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
|
||||
import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots;
|
||||
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
|
||||
import org.apache.doris.nereids.rules.analysis.NormalizeRepeat;
|
||||
import org.apache.doris.nereids.rules.analysis.ProjectToGlobalAggregate;
|
||||
import org.apache.doris.nereids.rules.analysis.ProjectWithDistinctToAggregate;
|
||||
@ -110,9 +112,14 @@ public class Analyzer extends AbstractBatchJobExecutor {
|
||||
// LogicalProject for normalize. This rule depends on FillUpMissingSlots to fill up slots.
|
||||
new NormalizeRepeat()
|
||||
),
|
||||
bottomUp(new SubqueryToApply()),
|
||||
bottomUp(new AdjustAggregateNullableForEmptySet()),
|
||||
bottomUp(new CheckAnalysis())
|
||||
// run CheckAnalysis before EliminateGroupByConstant in order to report error message correctly like bellow
|
||||
// select SUM(lo_tax) FROM lineorder group by 1;
|
||||
// errCode = 2, detailMessage = GROUP BY expression must not contain aggregate functions: sum(lo_tax)
|
||||
bottomUp(new CheckAnalysis()),
|
||||
topDown(new EliminateGroupByConstant()),
|
||||
topDown(new NormalizeAggregate()),
|
||||
bottomUp(new SubqueryToApply())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -25,7 +25,9 @@ import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet;
|
||||
import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount;
|
||||
import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite;
|
||||
import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
|
||||
import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject;
|
||||
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
|
||||
import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionNormalization;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionOptimization;
|
||||
@ -52,7 +54,6 @@ import org.apache.doris.nereids.rules.rewrite.EliminateAggregate;
|
||||
import org.apache.doris.nereids.rules.rewrite.EliminateDedupJoinCondition;
|
||||
import org.apache.doris.nereids.rules.rewrite.EliminateEmptyRelation;
|
||||
import org.apache.doris.nereids.rules.rewrite.EliminateFilter;
|
||||
import org.apache.doris.nereids.rules.rewrite.EliminateGroupByConstant;
|
||||
import org.apache.doris.nereids.rules.rewrite.EliminateLimit;
|
||||
import org.apache.doris.nereids.rules.rewrite.EliminateNotNull;
|
||||
import org.apache.doris.nereids.rules.rewrite.EliminateNullAwareLeftAntiJoin;
|
||||
@ -74,12 +75,12 @@ import org.apache.doris.nereids.rules.rewrite.MergeFilters;
|
||||
import org.apache.doris.nereids.rules.rewrite.MergeOneRowRelationIntoUnion;
|
||||
import org.apache.doris.nereids.rules.rewrite.MergeProjects;
|
||||
import org.apache.doris.nereids.rules.rewrite.MergeSetOperations;
|
||||
import org.apache.doris.nereids.rules.rewrite.NormalizeAggregate;
|
||||
import org.apache.doris.nereids.rules.rewrite.NormalizeSort;
|
||||
import org.apache.doris.nereids.rules.rewrite.PruneFileScanPartition;
|
||||
import org.apache.doris.nereids.rules.rewrite.PruneOlapScanPartition;
|
||||
import org.apache.doris.nereids.rules.rewrite.PruneOlapScanTablet;
|
||||
import org.apache.doris.nereids.rules.rewrite.PullUpCteAnchor;
|
||||
import org.apache.doris.nereids.rules.rewrite.PullUpProjectUnderApply;
|
||||
import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoEsScan;
|
||||
import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoJdbcScan;
|
||||
import org.apache.doris.nereids.rules.rewrite.PushFilterInsideJoin;
|
||||
@ -139,6 +140,10 @@ public class Rewriter extends AbstractBatchJobExecutor {
|
||||
),
|
||||
// subquery unnesting relay on ExpressionNormalization to extract common factor expression
|
||||
topic("Subquery unnesting",
|
||||
// after doing NormalizeAggregate in analysis job
|
||||
// we need run the following 2 rules to make AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION work
|
||||
bottomUp(new PullUpProjectUnderApply()),
|
||||
topDown(new PushdownFilterThroughProject()),
|
||||
costBased(
|
||||
custom(RuleType.AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION,
|
||||
AggScalarSubQueryToWindowFunction::new)
|
||||
|
||||
@ -15,12 +15,13 @@
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule;
|
||||
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
@ -15,10 +15,12 @@
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
|
||||
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
@ -21,9 +21,7 @@ import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.StatementContext;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
|
||||
import org.apache.doris.nereids.trees.expressions.CaseWhen;
|
||||
import org.apache.doris.nereids.trees.expressions.Exists;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.InSubquery;
|
||||
@ -47,7 +45,6 @@ import com.google.common.collect.ImmutableSet;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
@ -68,8 +65,8 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
logicalFilter().thenApply(ctx -> {
|
||||
LogicalFilter<Plan> filter = ctx.root;
|
||||
|
||||
ImmutableList<Set> subqueryExprsList = filter.getConjuncts().stream()
|
||||
.map(e -> (Set) e.collect(SubqueryExpr.class::isInstance))
|
||||
ImmutableList<Set<SubqueryExpr>> subqueryExprsList = filter.getConjuncts().stream()
|
||||
.map(e -> (Set<SubqueryExpr>) e.collect(SubqueryExpr.class::isInstance))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
if (subqueryExprsList.stream()
|
||||
.flatMap(Collection::stream).noneMatch(SubqueryExpr.class::isInstance)) {
|
||||
@ -104,8 +101,7 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
tmpPlan = applyPlan;
|
||||
newConjuncts.add(conjunct);
|
||||
}
|
||||
Set<Expression> conjuncts = new LinkedHashSet<>();
|
||||
conjuncts.addAll(newConjuncts.build());
|
||||
Set<Expression> conjuncts = ImmutableSet.copyOf(newConjuncts.build());
|
||||
Plan newFilter = new LogicalFilter<>(conjuncts, applyPlan);
|
||||
if (conjuncts.stream().flatMap(c -> c.children().stream())
|
||||
.anyMatch(MarkJoinSlotReference.class::isInstance)) {
|
||||
@ -116,36 +112,44 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
return new LogicalFilter<>(conjuncts, applyPlan);
|
||||
})
|
||||
),
|
||||
RuleType.PROJECT_SUBQUERY_TO_APPLY.build(
|
||||
logicalProject().thenApply(ctx -> {
|
||||
LogicalProject<Plan> project = ctx.root;
|
||||
Set<SubqueryExpr> subqueryExprs = new LinkedHashSet<>();
|
||||
project.getProjects().stream()
|
||||
.filter(Alias.class::isInstance)
|
||||
.map(Alias.class::cast)
|
||||
.filter(alias -> alias.child() instanceof CaseWhen)
|
||||
.forEach(alias -> alias.child().children().stream()
|
||||
.forEach(e ->
|
||||
subqueryExprs.addAll(e.collect(SubqueryExpr.class::isInstance))));
|
||||
if (subqueryExprs.isEmpty()) {
|
||||
return project;
|
||||
}
|
||||
RuleType.PROJECT_SUBQUERY_TO_APPLY.build(logicalProject().thenApply(ctx -> {
|
||||
LogicalProject<Plan> project = ctx.root;
|
||||
ImmutableList<Set<SubqueryExpr>> subqueryExprsList = project.getProjects().stream()
|
||||
.map(e -> (Set<SubqueryExpr>) e.collect(SubqueryExpr.class::isInstance))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
if (subqueryExprsList.stream().flatMap(Collection::stream).count() == 0) {
|
||||
return project;
|
||||
}
|
||||
List<NamedExpression> oldProjects = ImmutableList.copyOf(project.getProjects());
|
||||
ImmutableList.Builder<NamedExpression> newProjects = new ImmutableList.Builder<>();
|
||||
LogicalPlan childPlan = (LogicalPlan) project.child();
|
||||
LogicalPlan applyPlan;
|
||||
for (int i = 0; i < subqueryExprsList.size(); ++i) {
|
||||
Set<SubqueryExpr> subqueryExprs = subqueryExprsList.get(i);
|
||||
if (subqueryExprs.isEmpty()) {
|
||||
newProjects.add(oldProjects.get(i));
|
||||
continue;
|
||||
}
|
||||
|
||||
SubqueryContext context = new SubqueryContext(subqueryExprs);
|
||||
return new LogicalProject(project.getProjects().stream()
|
||||
.map(p -> p.withChildren(
|
||||
new ReplaceSubquery(ctx.statementContext, true)
|
||||
.replace(p, context)))
|
||||
.collect(ImmutableList.toImmutableList()),
|
||||
subqueryToApply(
|
||||
subqueryExprs.stream().collect(ImmutableList.toImmutableList()),
|
||||
(LogicalPlan) project.child(),
|
||||
context.getSubqueryToMarkJoinSlot(),
|
||||
ctx.cascadesContext,
|
||||
Optional.empty(), true
|
||||
));
|
||||
})
|
||||
)
|
||||
// first step: Replace the subquery in logcialProject's project list
|
||||
// second step: Replace subquery with LogicalApply
|
||||
ReplaceSubquery replaceSubquery =
|
||||
new ReplaceSubquery(ctx.statementContext, true);
|
||||
SubqueryContext context = new SubqueryContext(subqueryExprs);
|
||||
Expression newProject =
|
||||
replaceSubquery.replace(oldProjects.get(i), context);
|
||||
|
||||
applyPlan = subqueryToApply(
|
||||
subqueryExprs.stream().collect(ImmutableList.toImmutableList()),
|
||||
childPlan, context.getSubqueryToMarkJoinSlot(),
|
||||
ctx.cascadesContext,
|
||||
Optional.of(newProject), true);
|
||||
childPlan = applyPlan;
|
||||
newProjects.add((NamedExpression) newProject);
|
||||
}
|
||||
|
||||
return project.withProjectsAndChild(newProjects.build(), childPlan);
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
@ -249,28 +253,30 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
// The result set when NULL is specified in the subquery and still evaluates to TRUE by using EXISTS
|
||||
// When the number of rows returned is empty, agg will return null, so if there is more agg,
|
||||
// it will always consider the returned result to be true
|
||||
boolean needCreateMarkJoinSlot = isMarkJoin || isProject;
|
||||
MarkJoinSlotReference markJoinSlotReference = null;
|
||||
if (exists.getQueryPlan().anyMatch(Aggregate.class::isInstance) && isMarkJoin) {
|
||||
if (exists.getQueryPlan().anyMatch(Aggregate.class::isInstance) && needCreateMarkJoinSlot) {
|
||||
markJoinSlotReference =
|
||||
new MarkJoinSlotReference(statementContext.generateColumnName(), true);
|
||||
} else if (isMarkJoin) {
|
||||
} else if (needCreateMarkJoinSlot) {
|
||||
markJoinSlotReference =
|
||||
new MarkJoinSlotReference(statementContext.generateColumnName());
|
||||
}
|
||||
if (isMarkJoin) {
|
||||
if (needCreateMarkJoinSlot) {
|
||||
context.setSubqueryToMarkJoinSlot(exists, Optional.of(markJoinSlotReference));
|
||||
}
|
||||
return isMarkJoin ? markJoinSlotReference : BooleanLiteral.TRUE;
|
||||
return needCreateMarkJoinSlot ? markJoinSlotReference : BooleanLiteral.TRUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitInSubquery(InSubquery in, SubqueryContext context) {
|
||||
MarkJoinSlotReference markJoinSlotReference =
|
||||
new MarkJoinSlotReference(statementContext.generateColumnName());
|
||||
if (isMarkJoin) {
|
||||
boolean needCreateMarkJoinSlot = isMarkJoin || isProject;
|
||||
if (needCreateMarkJoinSlot) {
|
||||
context.setSubqueryToMarkJoinSlot(in, Optional.of(markJoinSlotReference));
|
||||
}
|
||||
return isMarkJoin ? markJoinSlotReference : BooleanLiteral.TRUE;
|
||||
return needCreateMarkJoinSlot ? markJoinSlotReference : BooleanLiteral.TRUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -29,8 +29,8 @@ import org.apache.doris.nereids.properties.PhysicalProperties;
|
||||
import org.apache.doris.nereids.properties.RequireProperties;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
|
||||
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE;
|
||||
import org.apache.doris.nereids.rules.rewrite.NormalizeAggregate;
|
||||
import org.apache.doris.nereids.trees.expressions.AggregateExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
|
||||
@ -22,6 +22,7 @@ import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
@ -45,7 +46,9 @@ public class PushdownAliasThroughJoin extends OneRewriteRuleFactory {
|
||||
public Rule build() {
|
||||
return logicalProject(logicalJoin())
|
||||
.when(project -> project.getProjects().stream().allMatch(expr ->
|
||||
(expr instanceof Slot) || (expr instanceof Alias && ((Alias) expr).child() instanceof Slot)))
|
||||
(expr instanceof Slot && !(expr instanceof MarkJoinSlotReference))
|
||||
|| (expr instanceof Alias && ((Alias) expr).child() instanceof Slot
|
||||
&& !(((Alias) expr).child() instanceof MarkJoinSlotReference))))
|
||||
.when(project -> project.getProjects().stream().anyMatch(expr -> expr instanceof Alias))
|
||||
.then(project -> {
|
||||
LogicalJoin<? extends Plan, ? extends Plan> join = project.child();
|
||||
|
||||
@ -94,7 +94,7 @@ public class CaseWhen extends Expression {
|
||||
StringBuilder output = new StringBuilder("CASE");
|
||||
for (Expression child : children()) {
|
||||
if (child instanceof WhenClause) {
|
||||
output.append(child);
|
||||
output.append(child.toString());
|
||||
} else {
|
||||
output.append(" ELSE ").append(child.toString());
|
||||
}
|
||||
|
||||
@ -58,4 +58,9 @@ public class DoubleLiteral extends Literal {
|
||||
nf.setGroupingUsed(false);
|
||||
return nf.format(value);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getStringValue() {
|
||||
return toString();
|
||||
}
|
||||
}
|
||||
|
||||
@ -22,6 +22,8 @@ import org.apache.doris.catalog.Type;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.FloatType;
|
||||
|
||||
import java.text.NumberFormat;
|
||||
|
||||
/**
|
||||
* float type literal
|
||||
*/
|
||||
@ -48,4 +50,11 @@ public class FloatLiteral extends Literal {
|
||||
public LiteralExpr toLegacyLiteral() {
|
||||
return new org.apache.doris.analysis.FloatLiteral((double) value, Type.FLOAT);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getStringValue() {
|
||||
NumberFormat nf = NumberFormat.getInstance();
|
||||
nf.setGroupingUsed(false);
|
||||
return nf.format(value);
|
||||
}
|
||||
}
|
||||
|
||||
@ -140,7 +140,7 @@ public class AnalyzeCTETest extends TestWithFeService implements MemoPatternMatc
|
||||
logicalFilter(
|
||||
logicalProject(
|
||||
logicalJoin(
|
||||
logicalAggregate(),
|
||||
logicalProject(logicalAggregate()),
|
||||
logicalProject()
|
||||
)
|
||||
)
|
||||
|
||||
@ -156,18 +156,20 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP
|
||||
.matchesNotCheck(
|
||||
logicalApply(
|
||||
any(),
|
||||
logicalAggregate(
|
||||
logicalFilter()
|
||||
).when(FieldChecker.check("outputExpressions", ImmutableList.of(
|
||||
new Alias(new ExprId(7),
|
||||
(new Sum(
|
||||
new SlotReference(new ExprId(4), "k3",
|
||||
BigIntType.INSTANCE, true,
|
||||
ImmutableList.of(
|
||||
"default_cluster:test",
|
||||
"t7")))).withAlwaysNullable(
|
||||
true),
|
||||
"sum(k3)"))))
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject()
|
||||
).when(FieldChecker.check("outputExpressions", ImmutableList.of(
|
||||
new Alias(new ExprId(7),
|
||||
(new Sum(
|
||||
new SlotReference(new ExprId(4), "k3",
|
||||
BigIntType.INSTANCE, true,
|
||||
ImmutableList.of(
|
||||
"default_cluster:test",
|
||||
"t7")))).withAlwaysNullable(
|
||||
true),
|
||||
"sum(k3)"))))
|
||||
)
|
||||
).when(FieldChecker.check("correlationSlot", ImmutableList.of(
|
||||
new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true,
|
||||
ImmutableList.of("default_cluster:test", "t6"))
|
||||
@ -383,28 +385,32 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP
|
||||
logicalProject(
|
||||
logicalApply(
|
||||
any(),
|
||||
logicalAggregate(
|
||||
logicalSubQueryAlias(
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(
|
||||
logicalFilter()
|
||||
).when(p -> p.getProjects().equals(ImmutableList.of(
|
||||
new Alias(new ExprId(7), new SlotReference(new ExprId(5), "v1", BigIntType.INSTANCE,
|
||||
true,
|
||||
ImmutableList.of("default_cluster:test", "t7")), "aa")
|
||||
)))
|
||||
)
|
||||
.when(a -> a.getAlias().equals("t2"))
|
||||
.when(a -> a.getOutput().equals(ImmutableList.of(
|
||||
new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE,
|
||||
true, ImmutableList.of("t2"))
|
||||
logicalSubQueryAlias(
|
||||
logicalProject(
|
||||
logicalFilter()
|
||||
).when(p -> p.getProjects().equals(ImmutableList.of(
|
||||
new Alias(new ExprId(7), new SlotReference(new ExprId(5), "v1", BigIntType.INSTANCE,
|
||||
true,
|
||||
ImmutableList.of("default_cluster:test", "t7")), "aa")
|
||||
)))
|
||||
)
|
||||
.when(a -> a.getAlias().equals("t2"))
|
||||
.when(a -> a.getOutput().equals(ImmutableList.of(
|
||||
new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE,
|
||||
true, ImmutableList.of("t2"))
|
||||
)))
|
||||
)
|
||||
).when(agg -> agg.getOutputExpressions().equals(ImmutableList.of(
|
||||
new Alias(new ExprId(8),
|
||||
(new Max(new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE,
|
||||
true,
|
||||
ImmutableList.of("t2")))).withAlwaysNullable(true), "max(aa)")
|
||||
)))
|
||||
).when(agg -> agg.getOutputExpressions().equals(ImmutableList.of(
|
||||
new Alias(new ExprId(8),
|
||||
(new Max(new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE,
|
||||
true,
|
||||
ImmutableList.of("t2")))).withAlwaysNullable(true), "max(aa)")
|
||||
)))
|
||||
.when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of()))
|
||||
.when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of()))
|
||||
)
|
||||
)
|
||||
.when(apply -> apply.getCorrelationSlot().equals(ImmutableList.of(
|
||||
new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true,
|
||||
|
||||
@ -90,10 +90,11 @@ class BindSlotReferenceTest {
|
||||
join
|
||||
);
|
||||
PlanChecker checker = PlanChecker.from(MemoTestUtils.createConnectContext()).analyze(aggregate);
|
||||
LogicalAggregate plan = (LogicalAggregate) checker.getCascadesContext().getMemo().copyOut();
|
||||
LogicalAggregate plan = (LogicalAggregate) ((LogicalProject) checker.getCascadesContext()
|
||||
.getMemo().copyOut()).child();
|
||||
SlotReference groupByKey = (SlotReference) plan.getGroupByExpressions().get(0);
|
||||
SlotReference t1id = (SlotReference) ((LogicalJoin) plan.child()).left().getOutput().get(0);
|
||||
SlotReference t2id = (SlotReference) ((LogicalJoin) plan.child()).right().getOutput().get(0);
|
||||
SlotReference t1id = (SlotReference) ((LogicalJoin) plan.child().child(0)).left().getOutput().get(0);
|
||||
SlotReference t2id = (SlotReference) ((LogicalJoin) plan.child().child(0)).right().getOutput().get(0);
|
||||
Assertions.assertEquals(groupByKey.getExprId(), t1id.getExprId());
|
||||
Assertions.assertNotEquals(t1id.getExprId(), t2id.getExprId());
|
||||
}
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.catalog.AggregateType;
|
||||
import org.apache.doris.catalog.Column;
|
||||
@ -23,7 +23,6 @@ import org.apache.doris.catalog.KeysType;
|
||||
import org.apache.doris.catalog.OlapTable;
|
||||
import org.apache.doris.catalog.PartitionInfo;
|
||||
import org.apache.doris.catalog.Type;
|
||||
import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite;
|
||||
import org.apache.doris.nereids.trees.expressions.Add;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
@ -35,6 +35,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.IntegerType;
|
||||
import org.apache.doris.nereids.types.TinyIntType;
|
||||
import org.apache.doris.nereids.util.FieldChecker;
|
||||
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
@ -45,8 +46,6 @@ import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Lists;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements MemoPatternMatchSupported {
|
||||
|
||||
@Override
|
||||
@ -86,35 +85,35 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
|
||||
);
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matches(
|
||||
logicalFilter(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1)))));
|
||||
logicalFilter(
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(logicalOlapScan())
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1))))));
|
||||
|
||||
sql = "SELECT a1 as value FROM t1 GROUP BY a1 HAVING a1 > 0";
|
||||
a1 = new SlotReference(
|
||||
new ExprId(1), "a1", TinyIntType.INSTANCE, true,
|
||||
ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1")
|
||||
);
|
||||
Alias value = new Alias(new ExprId(3), a1, "value");
|
||||
SlotReference value = new SlotReference(new ExprId(3), "value", TinyIntType.INSTANCE, true,
|
||||
ImmutableList.of());
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.applyBottomUp(new ExpressionRewrite(FunctionBinder.INSTANCE))
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value)))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0)))))));
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(logicalOlapScan())
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value))))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0)))))));
|
||||
|
||||
sql = "SELECT a1 as value FROM t1 GROUP BY a1 HAVING value > 0";
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.applyBottomUp(new ExpressionRewrite(FunctionBinder.INSTANCE))
|
||||
.matches(
|
||||
logicalFilter(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value)))
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(logicalOlapScan())
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value))))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0))))));
|
||||
|
||||
sql = "SELECT SUM(a2) FROM t1 GROUP BY a1 HAVING a1 > 0";
|
||||
@ -130,13 +129,14 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.applyBottomUp(new ExpressionRewrite(FunctionBinder.INSTANCE))
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(sumA2, a1)))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(a1, new TinyIntLiteral((byte) 0)))))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(sumA2.toSlot()))));
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(logicalOlapScan())
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(a1, new TinyIntLiteral((byte) 0)))))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(sumA2.toSlot()))));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -153,24 +153,28 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
|
||||
Alias sumA2 = new Alias(new ExprId(3), new Sum(a2), "sum(a2)");
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(logicalOlapScan())
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));
|
||||
|
||||
sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 HAVING SUM(a2) > 0";
|
||||
sumA2 = new Alias(new ExprId(3), new Sum(a2), "SUM(a2)");
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))));
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(
|
||||
logicalOlapScan()
|
||||
)
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))));
|
||||
|
||||
sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 HAVING SUM(a2) > 0";
|
||||
a1 = new SlotReference(
|
||||
@ -184,20 +188,24 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
|
||||
Alias value = new Alias(new ExprId(3), new Sum(a2), "value");
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L)))))));
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(
|
||||
logicalOlapScan())
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L)))))));
|
||||
|
||||
sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 HAVING value > 0";
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matches(
|
||||
logicalFilter(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(
|
||||
logicalOlapScan())
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L))))));
|
||||
|
||||
sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 HAVING MIN(pk) > 0";
|
||||
@ -217,49 +225,53 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
|
||||
Alias minPK = new Alias(new ExprId(4), new Min(pk), "min(pk)");
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK)))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(minPK.toSlot(), Literal.of((byte) 0)))))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(logicalOlapScan())
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK))))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(minPK.toSlot(), Literal.of((byte) 0)))))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));
|
||||
|
||||
sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 HAVING SUM(a1 + a2) > 0";
|
||||
Alias sumA1A2 = new Alias(new ExprId(3), new Sum(new Add(a1, a2)), "SUM((a1 + a2))");
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2)))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A2.toSlot(), Literal.of(0L)))))));
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(logicalOlapScan())
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2))))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A2.toSlot(), Literal.of(0L)))))));
|
||||
|
||||
sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 HAVING SUM(a1 + a2 + 3) > 0";
|
||||
Alias sumA1A23 = new Alias(new ExprId(4), new Sum(new Add(new Add(a1, a2), new TinyIntLiteral((byte) 3))),
|
||||
"sum(((a1 + a2) + 3))");
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23)))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A23.toSlot(), Literal.of(0L)))))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot()))));
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(logicalOlapScan())
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23))))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A23.toSlot(), Literal.of(0L)))))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot()))));
|
||||
|
||||
sql = "SELECT a1 FROM t1 GROUP BY a1 HAVING COUNT(*) > 0";
|
||||
Alias countStar = new Alias(new ExprId(3), new Count(), "count(*)");
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar)))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(countStar.toSlot(), Literal.of(0L)))))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(logicalOlapScan())
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar))))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(countStar.toSlot(), Literal.of(0L)))))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -281,19 +293,21 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
|
||||
Alias sumB1 = new Alias(new ExprId(7), new Sum(b1), "sum(b1)");
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalAggregate(
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalJoin(
|
||||
logicalOlapScan(),
|
||||
logicalOlapScan()
|
||||
)
|
||||
)
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, sumB1)))
|
||||
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(new Cast(a1, BigIntType.INSTANCE),
|
||||
sumB1.toSlot()))))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalJoin(
|
||||
logicalOlapScan(),
|
||||
logicalOlapScan()
|
||||
)
|
||||
))
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, sumB1)))
|
||||
)).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(new Cast(a1, BigIntType.INSTANCE),
|
||||
sumB1.toSlot()))))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -331,6 +345,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
|
||||
new ExprId(0), "pk", TinyIntType.INSTANCE, true,
|
||||
ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1")
|
||||
);
|
||||
SlotReference pk1 = new SlotReference(
|
||||
new ExprId(6), "(pk + 1)", IntegerType.INSTANCE, true,
|
||||
ImmutableList.of()
|
||||
);
|
||||
SlotReference a1 = new SlotReference(
|
||||
new ExprId(1), "a1", TinyIntType.INSTANCE, true,
|
||||
ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1")
|
||||
@ -339,40 +357,42 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
|
||||
new ExprId(2), "a2", TinyIntType.INSTANCE, true,
|
||||
ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1")
|
||||
);
|
||||
Alias pk1 = new Alias(new ExprId(6), new Add(pk, Literal.of((byte) 1)), "(pk + 1)");
|
||||
Alias pk11 = new Alias(new ExprId(7), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)");
|
||||
Alias pk2 = new Alias(new ExprId(8), new Add(pk, Literal.of((byte) 2)), "(pk + 2)");
|
||||
Alias sumA1 = new Alias(new ExprId(9), new Sum(a1), "SUM(a1)");
|
||||
Alias countA11 = new Alias(new ExprId(10), new Add(new Count(a1), Literal.of((byte) 1)), "(COUNT(a1) + 1)");
|
||||
Alias countA1 = new Alias(new ExprId(13), new Count(a1), "count(a1)");
|
||||
Alias countA11 = new Alias(new ExprId(10), new Add(countA1.toSlot(), Literal.of((byte) 1)), "(COUNT(a1) + 1)");
|
||||
Alias sumA1A2 = new Alias(new ExprId(11), new Sum(new Add(a1, a2)), "SUM((a1 + a2))");
|
||||
Alias v1 = new Alias(new ExprId(12), new Count(a2), "v1");
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalAggregate(
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalJoin(
|
||||
logicalOlapScan(),
|
||||
logicalOlapScan()
|
||||
)
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalJoin(
|
||||
logicalOlapScan(),
|
||||
logicalOlapScan()
|
||||
)
|
||||
))
|
||||
).when(FieldChecker.check("outputExpressions",
|
||||
Lists.newArrayList(pk, pk1, sumA1, countA1, sumA1A2, v1))))
|
||||
).when(FieldChecker.check("conjuncts",
|
||||
ImmutableSet.of(
|
||||
new GreaterThan(pk.toSlot(), Literal.of((byte) 0)),
|
||||
new GreaterThan(countA11.toSlot(), Literal.of(0L)),
|
||||
new GreaterThan(new Add(sumA1A2.toSlot(), Literal.of((byte) 1)), Literal.of(0L)),
|
||||
new GreaterThan(new Add(v1.toSlot(), Literal.of((byte) 1)), Literal.of(0L)),
|
||||
new GreaterThan(v1.toSlot(), Literal.of(0L))
|
||||
))
|
||||
)
|
||||
).when(FieldChecker.check("outputExpressions",
|
||||
Lists.newArrayList(pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1, pk)))
|
||||
).when(FieldChecker.check("conjuncts",
|
||||
ImmutableSet.of(
|
||||
new GreaterThan(pk.toSlot(), Literal.of((byte) 0)),
|
||||
new GreaterThan(countA11.toSlot(), Literal.of(0L)),
|
||||
new GreaterThan(new Add(sumA1A2.toSlot(), Literal.of((byte) 1)), Literal.of(0L)),
|
||||
new GreaterThan(new Add(v1.toSlot(), Literal.of((byte) 1)), Literal.of(0L)),
|
||||
new GreaterThan(v1.toSlot(), Literal.of(0L))
|
||||
))
|
||||
)
|
||||
).when(FieldChecker.check(
|
||||
"projects", Lists.newArrayList(
|
||||
pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1).stream()
|
||||
.map(Alias::toSlot).collect(Collectors.toList()))
|
||||
));
|
||||
).when(FieldChecker.check(
|
||||
"projects", Lists.newArrayList(
|
||||
pk1, pk11.toSlot(), pk2.toSlot(), sumA1.toSlot(), countA11.toSlot(), sumA1A2.toSlot(), v1.toSlot())
|
||||
)
|
||||
));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -391,9 +411,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalSort(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(logicalOlapScan())
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
|
||||
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true))))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));
|
||||
|
||||
@ -402,9 +423,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matches(
|
||||
logicalSort(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(logicalOlapScan())
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
|
||||
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true)))));
|
||||
|
||||
sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 ORDER BY SUM(a2)";
|
||||
@ -420,9 +442,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matches(
|
||||
logicalSort(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(logicalOlapScan())
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))))
|
||||
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true)))));
|
||||
|
||||
sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 ORDER BY MIN(pk)";
|
||||
@ -444,9 +467,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalSort(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK)))
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(logicalOlapScan())
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK))))
|
||||
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(minPK.toSlot(), true, true))))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));
|
||||
|
||||
@ -455,9 +479,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matches(
|
||||
logicalSort(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2)))
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(logicalOlapScan())
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2))))
|
||||
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA1A2.toSlot(), true, true)))));
|
||||
|
||||
sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 ORDER BY SUM(a1 + a2 + 3)";
|
||||
@ -467,9 +492,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalSort(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23)))
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(logicalOlapScan())
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23))))
|
||||
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA1A23.toSlot(), true, true))))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot()))));
|
||||
|
||||
@ -479,9 +505,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalSort(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar)))
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(logicalOlapScan())
|
||||
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar))))
|
||||
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(countStar.toSlot(), true, true))))
|
||||
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));
|
||||
}
|
||||
@ -495,6 +522,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
|
||||
new ExprId(0), "pk", TinyIntType.INSTANCE, true,
|
||||
ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1")
|
||||
);
|
||||
SlotReference pk1 = new SlotReference(
|
||||
new ExprId(6), "(pk + 1)", IntegerType.INSTANCE, true,
|
||||
ImmutableList.of()
|
||||
);
|
||||
SlotReference a1 = new SlotReference(
|
||||
new ExprId(1), "a1", TinyIntType.INSTANCE, true,
|
||||
ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1")
|
||||
@ -503,40 +534,41 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
|
||||
new ExprId(2), "a2", TinyIntType.INSTANCE, true,
|
||||
ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1")
|
||||
);
|
||||
Alias pk1 = new Alias(new ExprId(6), new Add(pk, Literal.of((byte) 1)), "(pk + 1)");
|
||||
Alias pk11 = new Alias(new ExprId(7), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)");
|
||||
Alias pk2 = new Alias(new ExprId(8), new Add(pk, Literal.of((byte) 2)), "(pk + 2)");
|
||||
Alias sumA1 = new Alias(new ExprId(9), new Sum(a1), "SUM(a1)");
|
||||
Alias countA1 = new Alias(new ExprId(13), new Count(a1), "count(a1)");
|
||||
Alias countA11 = new Alias(new ExprId(10), new Add(new Count(a1), Literal.of((byte) 1)), "(COUNT(a1) + 1)");
|
||||
Alias sumA1A2 = new Alias(new ExprId(11), new Sum(new Add(a1, a2)), "SUM((a1 + a2))");
|
||||
Alias v1 = new Alias(new ExprId(12), new Count(a2), "v1");
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalSort(
|
||||
logicalAggregate(
|
||||
logicalFilter(
|
||||
logicalJoin(
|
||||
logicalOlapScan(),
|
||||
logicalOlapScan()
|
||||
)
|
||||
)
|
||||
).when(FieldChecker.check("outputExpressions",
|
||||
Lists.newArrayList(pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1, pk)))
|
||||
).when(FieldChecker.check("orderKeys",
|
||||
ImmutableList.of(
|
||||
new OrderKey(pk, true, true),
|
||||
new OrderKey(countA11.toSlot(), true, true),
|
||||
new OrderKey(new Add(sumA1A2.toSlot(), new TinyIntLiteral((byte) 1)), true, true),
|
||||
new OrderKey(new Add(v1.toSlot(), new TinyIntLiteral((byte) 1)), true, true),
|
||||
new OrderKey(v1.toSlot(), true, true)
|
||||
)
|
||||
))
|
||||
).when(FieldChecker.check(
|
||||
"projects", Lists.newArrayList(
|
||||
pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1).stream()
|
||||
.map(Alias::toSlot).collect(Collectors.toList()))
|
||||
));
|
||||
.matches(logicalProject(logicalSort(logicalProject(logicalAggregate(logicalProject(
|
||||
logicalFilter(logicalJoin(logicalOlapScan(), logicalOlapScan())))).when(
|
||||
FieldChecker.check("outputExpressions", Lists.newArrayList(pk, pk1,
|
||||
sumA1, countA1, sumA1A2, v1))))).when(FieldChecker.check(
|
||||
"orderKeys",
|
||||
ImmutableList.of(new OrderKey(pk, true, true),
|
||||
new OrderKey(
|
||||
countA11.toSlot(), true, true),
|
||||
new OrderKey(
|
||||
new Add(sumA1A2.toSlot(),
|
||||
new TinyIntLiteral(
|
||||
(byte) 1)),
|
||||
true, true),
|
||||
new OrderKey(
|
||||
new Add(v1.toSlot(),
|
||||
new TinyIntLiteral(
|
||||
(byte) 1)),
|
||||
true, true),
|
||||
new OrderKey(v1.toSlot(), true, true)))))
|
||||
.when(FieldChecker.check("projects",
|
||||
Lists.newArrayList(pk1,
|
||||
pk11.toSlot(),
|
||||
pk2.toSlot(),
|
||||
sumA1.toSlot(),
|
||||
countA11.toSlot(),
|
||||
sumA1A2.toSlot(),
|
||||
v1.toSlot()))));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.Add;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.rewrite;
|
||||
import org.apache.doris.nereids.annotation.Developing;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
|
||||
import org.apache.doris.nereids.rules.implementation.AggregateStrategies;
|
||||
import org.apache.doris.nereids.trees.expressions.Add;
|
||||
import org.apache.doris.nereids.trees.expressions.AggregateExpression;
|
||||
|
||||
@ -299,15 +299,16 @@ public class ColumnPruningTest extends TestWithFeService implements MemoPatternM
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalSubQueryAlias(
|
||||
logicalAggregate(
|
||||
logicalProject(
|
||||
logicalOlapScan()
|
||||
).when(p -> getOutputQualifiedNames(p).equals(
|
||||
ImmutableList.of("default_cluster:test.student.id")
|
||||
))
|
||||
).when(agg -> getOutputQualifiedNames(agg.getOutputs()).equals(
|
||||
ImmutableList.of("default_cluster:test.student.id")
|
||||
))
|
||||
logicalAggregate(
|
||||
logicalProject(
|
||||
logicalOlapScan()
|
||||
).when(p -> getOutputQualifiedNames(p).equals(
|
||||
ImmutableList.of("default_cluster:test.student.id")
|
||||
))
|
||||
).when(agg -> getOutputQualifiedNames(agg.getOutputs()).equals(
|
||||
ImmutableList.of("default_cluster:test.student.id")
|
||||
)))
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.nereids.properties.OrderKey;
|
||||
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
|
||||
import org.apache.doris.nereids.trees.expressions.Add;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
@ -18,6 +18,10 @@
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
@ -30,6 +34,8 @@ import org.apache.doris.nereids.util.PlanConstructor;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
class PushdownAliasThroughJoinTest implements MemoPatternMatchSupported {
|
||||
private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
@ -99,4 +105,21 @@ class PushdownAliasThroughJoinTest implements MemoPatternMatchSupported {
|
||||
&& project.getProjects().get(1).toSql().equals("2name"))
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testNoPushdownMarkJoin() {
|
||||
List<NamedExpression> projects =
|
||||
ImmutableList.of(new MarkJoinSlotReference(new ExprId(101), "markSlot1", false),
|
||||
new Alias(new MarkJoinSlotReference(new ExprId(102), "markSlot2", false),
|
||||
"markSlot2"));
|
||||
LogicalPlan plan = new LogicalPlanBuilder(scan1)
|
||||
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)).projectExprs(projects).build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyTopDown(new PushdownAliasThroughJoin())
|
||||
.matches(logicalProject(logicalJoin(logicalOlapScan(), logicalOlapScan()))
|
||||
.when(project -> project.getProjects().get(0).toSql().equals("markSlot1")
|
||||
&& project.getProjects().get(1).toSql()
|
||||
.equals("markSlot2 AS `markSlot2`")));
|
||||
}
|
||||
}
|
||||
|
||||
@ -135,20 +135,22 @@ public class PushdownExpressionsInHashConditionTest extends TestWithFeService im
|
||||
.applyTopDown(new FindHashConditionForJoin())
|
||||
.applyTopDown(new PushdownExpressionsInHashCondition())
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalJoin(
|
||||
logicalProject(
|
||||
logicalOlapScan()
|
||||
),
|
||||
logicalProject(
|
||||
logicalSubQueryAlias(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
)
|
||||
logicalProject(
|
||||
logicalJoin(
|
||||
logicalProject(
|
||||
logicalOlapScan()
|
||||
),
|
||||
logicalProject(
|
||||
logicalSubQueryAlias(
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(
|
||||
logicalOlapScan()
|
||||
)))
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@ -168,8 +170,12 @@ public class PushdownExpressionsInHashConditionTest extends TestWithFeService im
|
||||
logicalProject(
|
||||
logicalSubQueryAlias(
|
||||
logicalSort(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalProject(
|
||||
logicalOlapScan()
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.rewrite.mv;
|
||||
import org.apache.doris.common.FeConstants;
|
||||
import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject;
|
||||
import org.apache.doris.nereids.rules.rewrite.MergeProjects;
|
||||
import org.apache.doris.nereids.rules.rewrite.PushdownFilterThroughProject;
|
||||
import org.apache.doris.nereids.trees.plans.PreAggStatus;
|
||||
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
@ -188,7 +189,8 @@ class SelectRollupIndexTest extends BaseMaterializedIndexSelectTest implements M
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze(sql)
|
||||
.applyBottomUp(new LogicalSubQueryAliasToLogicalProject())
|
||||
.applyTopDown(new MergeProjects())
|
||||
.applyTopDown(new PushdownFilterThroughProject())
|
||||
.applyBottomUp(new MergeProjects())
|
||||
.applyTopDown(new SelectMaterializedIndexWithAggregate())
|
||||
.applyTopDown(new SelectMaterializedIndexWithoutAggregate())
|
||||
.matches(logicalOlapScan().when(scan -> {
|
||||
|
||||
@ -0,0 +1,50 @@
|
||||
-- This file is automatically generated. You should know what you did if you want to edit this
|
||||
-- !sql1 --
|
||||
3
|
||||
|
||||
-- !sql2 --
|
||||
3
|
||||
|
||||
-- !sql3 --
|
||||
3
|
||||
|
||||
-- !sql4 --
|
||||
false
|
||||
|
||||
-- !sql5 --
|
||||
false
|
||||
|
||||
-- !sql6 --
|
||||
true
|
||||
|
||||
-- !sql7 --
|
||||
2
|
||||
|
||||
-- !sql8 --
|
||||
4
|
||||
4
|
||||
|
||||
-- !sql9 --
|
||||
4
|
||||
4
|
||||
|
||||
-- !sql10 --
|
||||
false
|
||||
true
|
||||
|
||||
-- !sql11 --
|
||||
false
|
||||
true
|
||||
|
||||
-- !sql12 --
|
||||
true
|
||||
true
|
||||
|
||||
-- !sql13 --
|
||||
2
|
||||
2
|
||||
|
||||
-- !sql14 --
|
||||
\N 2.0
|
||||
2020-09-09 2.0
|
||||
|
||||
@ -23,7 +23,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
|
||||
----------------PhysicalProject
|
||||
------------------PhysicalOlapScan[customer]
|
||||
--------------PhysicalDistribute
|
||||
----------------hashJoin[INNER_JOIN](ctr1.ctr_store_sk = ctr2.ctr_store_sk)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE))
|
||||
----------------hashJoin[INNER_JOIN](ctr1.ctr_store_sk = ctr2.ctr_store_sk)(cast(ctr_total_return as DOUBLE) > cast((avg(cast(ctr_total_return as DECIMALV3(38, 4))) * 1.2) as DOUBLE))
|
||||
------------------PhysicalProject
|
||||
--------------------hashJoin[INNER_JOIN](store.s_store_sk = ctr1.ctr_store_sk)
|
||||
----------------------PhysicalDistribute
|
||||
@ -32,11 +32,10 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
|
||||
------------------------PhysicalProject
|
||||
--------------------------filter((cast(s_state as VARCHAR(*)) = 'SD'))
|
||||
----------------------------PhysicalOlapScan[store]
|
||||
------------------PhysicalProject
|
||||
--------------------hashAgg[GLOBAL]
|
||||
----------------------PhysicalDistribute
|
||||
------------------------hashAgg[LOCAL]
|
||||
--------------------------PhysicalDistribute
|
||||
----------------------------PhysicalProject
|
||||
------------------------------PhysicalCteConsumer ( cteId=CTEId#0 )
|
||||
------------------hashAgg[GLOBAL]
|
||||
--------------------PhysicalDistribute
|
||||
----------------------hashAgg[LOCAL]
|
||||
------------------------PhysicalDistribute
|
||||
--------------------------PhysicalProject
|
||||
----------------------------PhysicalCteConsumer ( cteId=CTEId#0 )
|
||||
|
||||
|
||||
@ -24,7 +24,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
|
||||
------PhysicalDistribute
|
||||
--------PhysicalTopN
|
||||
----------PhysicalProject
|
||||
------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE))
|
||||
------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(cast(ctr_total_return as DECIMALV3(38, 4))) * 1.2) as DOUBLE))
|
||||
--------------hashJoin[INNER_JOIN](ctr1.ctr_customer_sk = customer.c_customer_sk)
|
||||
----------------PhysicalDistribute
|
||||
------------------PhysicalCteConsumer ( cteId=CTEId#0 )
|
||||
@ -38,11 +38,10 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
|
||||
--------------------------filter((cast(ca_state as VARCHAR(*)) = 'IN'))
|
||||
----------------------------PhysicalOlapScan[customer_address]
|
||||
--------------PhysicalDistribute
|
||||
----------------PhysicalProject
|
||||
------------------hashAgg[GLOBAL]
|
||||
--------------------PhysicalDistribute
|
||||
----------------------hashAgg[LOCAL]
|
||||
------------------------PhysicalDistribute
|
||||
--------------------------PhysicalProject
|
||||
----------------------------PhysicalCteConsumer ( cteId=CTEId#0 )
|
||||
----------------hashAgg[GLOBAL]
|
||||
------------------PhysicalDistribute
|
||||
--------------------hashAgg[LOCAL]
|
||||
----------------------PhysicalDistribute
|
||||
------------------------PhysicalProject
|
||||
--------------------------PhysicalCteConsumer ( cteId=CTEId#0 )
|
||||
|
||||
|
||||
@ -14,30 +14,32 @@ PhysicalResultSink
|
||||
----------------------PhysicalWindow
|
||||
------------------------PhysicalQuickSort
|
||||
--------------------------PhysicalDistribute
|
||||
----------------------------hashAgg[GLOBAL]
|
||||
------------------------------PhysicalDistribute
|
||||
--------------------------------hashAgg[LOCAL]
|
||||
----------------------------------PhysicalProject
|
||||
------------------------------------hashJoin[INNER_JOIN](store_sales.ss_sold_date_sk = date_dim.d_date_sk)
|
||||
--------------------------------------PhysicalProject
|
||||
----------------------------------------PhysicalOlapScan[store_sales]
|
||||
--------------------------------------PhysicalDistribute
|
||||
----------------------------PhysicalProject
|
||||
------------------------------hashAgg[GLOBAL]
|
||||
--------------------------------PhysicalDistribute
|
||||
----------------------------------hashAgg[LOCAL]
|
||||
------------------------------------PhysicalProject
|
||||
--------------------------------------hashJoin[INNER_JOIN](store_sales.ss_sold_date_sk = date_dim.d_date_sk)
|
||||
----------------------------------------PhysicalProject
|
||||
------------------------------------------filter((date_dim.d_month_seq <= 1227)(date_dim.d_month_seq >= 1216))
|
||||
--------------------------------------------PhysicalOlapScan[date_dim]
|
||||
------------------------------------------PhysicalOlapScan[store_sales]
|
||||
----------------------------------------PhysicalDistribute
|
||||
------------------------------------------PhysicalProject
|
||||
--------------------------------------------filter((date_dim.d_month_seq <= 1227)(date_dim.d_month_seq >= 1216))
|
||||
----------------------------------------------PhysicalOlapScan[date_dim]
|
||||
--------------------PhysicalProject
|
||||
----------------------PhysicalWindow
|
||||
------------------------PhysicalQuickSort
|
||||
--------------------------PhysicalDistribute
|
||||
----------------------------hashAgg[GLOBAL]
|
||||
------------------------------PhysicalDistribute
|
||||
--------------------------------hashAgg[LOCAL]
|
||||
----------------------------------PhysicalProject
|
||||
------------------------------------hashJoin[INNER_JOIN](web_sales.ws_sold_date_sk = date_dim.d_date_sk)
|
||||
--------------------------------------PhysicalProject
|
||||
----------------------------------------PhysicalOlapScan[web_sales]
|
||||
--------------------------------------PhysicalDistribute
|
||||
----------------------------PhysicalProject
|
||||
------------------------------hashAgg[GLOBAL]
|
||||
--------------------------------PhysicalDistribute
|
||||
----------------------------------hashAgg[LOCAL]
|
||||
------------------------------------PhysicalProject
|
||||
--------------------------------------hashJoin[INNER_JOIN](web_sales.ws_sold_date_sk = date_dim.d_date_sk)
|
||||
----------------------------------------PhysicalProject
|
||||
------------------------------------------filter((date_dim.d_month_seq >= 1216)(date_dim.d_month_seq <= 1227))
|
||||
--------------------------------------------PhysicalOlapScan[date_dim]
|
||||
------------------------------------------PhysicalOlapScan[web_sales]
|
||||
----------------------------------------PhysicalDistribute
|
||||
------------------------------------------PhysicalProject
|
||||
--------------------------------------------filter((date_dim.d_month_seq >= 1216)(date_dim.d_month_seq <= 1227))
|
||||
----------------------------------------------PhysicalOlapScan[date_dim]
|
||||
|
||||
|
||||
@ -24,7 +24,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
|
||||
------PhysicalDistribute
|
||||
--------PhysicalTopN
|
||||
----------PhysicalProject
|
||||
------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE))
|
||||
------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(cast(ctr_total_return as DECIMALV3(38, 4))) * 1.2) as DOUBLE))
|
||||
--------------hashJoin[INNER_JOIN](ctr1.ctr_customer_sk = customer.c_customer_sk)
|
||||
----------------PhysicalDistribute
|
||||
------------------PhysicalCteConsumer ( cteId=CTEId#0 )
|
||||
@ -38,11 +38,10 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
|
||||
--------------------------filter((cast(ca_state as VARCHAR(*)) = 'CA'))
|
||||
----------------------------PhysicalOlapScan[customer_address]
|
||||
--------------PhysicalDistribute
|
||||
----------------PhysicalProject
|
||||
------------------hashAgg[GLOBAL]
|
||||
--------------------PhysicalDistribute
|
||||
----------------------hashAgg[LOCAL]
|
||||
------------------------PhysicalDistribute
|
||||
--------------------------PhysicalProject
|
||||
----------------------------PhysicalCteConsumer ( cteId=CTEId#0 )
|
||||
----------------hashAgg[GLOBAL]
|
||||
------------------PhysicalDistribute
|
||||
--------------------hashAgg[LOCAL]
|
||||
----------------------PhysicalDistribute
|
||||
------------------------PhysicalProject
|
||||
--------------------------PhysicalCteConsumer ( cteId=CTEId#0 )
|
||||
|
||||
|
||||
@ -9,13 +9,12 @@ PhysicalResultSink
|
||||
------------PhysicalDistribute
|
||||
--------------PhysicalProject
|
||||
----------------hashJoin[INNER_JOIN](lineitem.l_partkey = partsupp.ps_partkey)(lineitem.l_suppkey = partsupp.ps_suppkey)(cast(ps_availqty as DECIMALV3(38, 3)) > (0.5 * sum(l_quantity)))
|
||||
------------------PhysicalProject
|
||||
--------------------hashAgg[GLOBAL]
|
||||
----------------------PhysicalDistribute
|
||||
------------------------hashAgg[LOCAL]
|
||||
--------------------------PhysicalProject
|
||||
----------------------------filter((lineitem.l_shipdate < 1995-01-01)(lineitem.l_shipdate >= 1994-01-01))
|
||||
------------------------------PhysicalOlapScan[lineitem]
|
||||
------------------hashAgg[GLOBAL]
|
||||
--------------------PhysicalDistribute
|
||||
----------------------hashAgg[LOCAL]
|
||||
------------------------PhysicalProject
|
||||
--------------------------filter((lineitem.l_shipdate < 1995-01-01)(lineitem.l_shipdate >= 1994-01-01))
|
||||
----------------------------PhysicalOlapScan[lineitem]
|
||||
------------------PhysicalDistribute
|
||||
--------------------hashJoin[LEFT_SEMI_JOIN](partsupp.ps_partkey = part.p_partkey)
|
||||
----------------------PhysicalProject
|
||||
|
||||
@ -9,13 +9,12 @@ PhysicalResultSink
|
||||
------------PhysicalDistribute
|
||||
--------------PhysicalProject
|
||||
----------------hashJoin[INNER_JOIN](lineitem.l_partkey = partsupp.ps_partkey)(lineitem.l_suppkey = partsupp.ps_suppkey)(cast(ps_availqty as DECIMALV3(38, 3)) > (0.5 * sum(l_quantity)))
|
||||
------------------PhysicalProject
|
||||
--------------------hashAgg[GLOBAL]
|
||||
----------------------PhysicalDistribute
|
||||
------------------------hashAgg[LOCAL]
|
||||
--------------------------PhysicalProject
|
||||
----------------------------filter((lineitem.l_shipdate < 1995-01-01)(lineitem.l_shipdate >= 1994-01-01))
|
||||
------------------------------PhysicalOlapScan[lineitem]
|
||||
------------------hashAgg[GLOBAL]
|
||||
--------------------PhysicalDistribute
|
||||
----------------------hashAgg[LOCAL]
|
||||
------------------------PhysicalProject
|
||||
--------------------------filter((lineitem.l_shipdate < 1995-01-01)(lineitem.l_shipdate >= 1994-01-01))
|
||||
----------------------------PhysicalOlapScan[lineitem]
|
||||
------------------PhysicalDistribute
|
||||
--------------------hashJoin[LEFT_SEMI_JOIN](partsupp.ps_partkey = part.p_partkey)
|
||||
----------------------PhysicalProject
|
||||
|
||||
@ -0,0 +1,120 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
suite("test_subquery_in_project") {
|
||||
sql "SET enable_nereids_planner=true"
|
||||
sql "SET enable_fallback_to_original_planner=false"
|
||||
sql """drop table if exists test_sql;"""
|
||||
sql """
|
||||
CREATE TABLE `test_sql` (
|
||||
`user_id` varchar(10) NULL,
|
||||
`dt` date NULL,
|
||||
`city` varchar(20) NULL,
|
||||
`age` int(11) NULL
|
||||
) ENGINE=OLAP
|
||||
UNIQUE KEY(`user_id`)
|
||||
COMMENT 'test'
|
||||
DISTRIBUTED BY HASH(`user_id`) BUCKETS 1
|
||||
PROPERTIES (
|
||||
"replication_allocation" = "tag.location.default: 1",
|
||||
"is_being_synced" = "false",
|
||||
"storage_format" = "V2",
|
||||
"light_schema_change" = "true",
|
||||
"disable_auto_compaction" = "false",
|
||||
"enable_single_replica_compaction" = "false"
|
||||
);
|
||||
"""
|
||||
|
||||
sql """ insert into test_sql values (1,'2020-09-09',2,3);"""
|
||||
|
||||
qt_sql1 """
|
||||
select (select age from test_sql) col from test_sql order by col;
|
||||
"""
|
||||
|
||||
qt_sql2 """
|
||||
select (select sum(age) from test_sql) col from test_sql order by col;
|
||||
"""
|
||||
|
||||
qt_sql3 """
|
||||
select (select sum(age) from test_sql t2 where t2.dt = t1.dt ) col from test_sql t1 order by col;
|
||||
"""
|
||||
|
||||
qt_sql4 """
|
||||
select age in (select user_id from test_sql) col from test_sql order by col;
|
||||
"""
|
||||
|
||||
qt_sql5 """
|
||||
select age in (select user_id from test_sql t2 where t2.user_id = t1.age) col from test_sql t1 order by col;
|
||||
"""
|
||||
|
||||
qt_sql6 """
|
||||
select exists ( select user_id from test_sql ) col from test_sql order by col;
|
||||
"""
|
||||
|
||||
qt_sql7 """
|
||||
select case when age in (select user_id from test_sql) or age in (select user_id from test_sql t2 where t2.user_id = t1.age) or exists ( select user_id from test_sql ) or exists ( select t2.user_id from test_sql t2 where t2.age = t1.user_id) or age < (select sum(age) from test_sql t2 where t2.dt = t1.dt ) then 2 else 1 end col from test_sql t1 order by col;
|
||||
"""
|
||||
|
||||
sql """ insert into test_sql values (2,'2020-09-09',2,1);"""
|
||||
|
||||
try {
|
||||
sql """
|
||||
select (select age from test_sql) col from test_sql order by col;
|
||||
"""
|
||||
} catch (Exception ex) {
|
||||
assertTrue(ex.getMessage().contains("Expected EQ 1 to be returned by expression"))
|
||||
}
|
||||
|
||||
qt_sql8 """
|
||||
select (select sum(age) from test_sql) col from test_sql order by col;
|
||||
"""
|
||||
|
||||
qt_sql9 """
|
||||
select (select sum(age) from test_sql t2 where t2.dt = t1.dt ) col from test_sql t1 order by col;
|
||||
"""
|
||||
|
||||
qt_sql10 """
|
||||
select age in (select user_id from test_sql) col from test_sql order by col;
|
||||
"""
|
||||
|
||||
qt_sql11 """
|
||||
select age in (select user_id from test_sql t2 where t2.user_id = t1.age) col from test_sql t1 order by col;
|
||||
"""
|
||||
|
||||
qt_sql12 """
|
||||
select exists ( select user_id from test_sql ) col from test_sql order by col;
|
||||
"""
|
||||
|
||||
qt_sql13 """
|
||||
select case when age in (select user_id from test_sql) or age in (select user_id from test_sql t2 where t2.user_id = t1.age) or exists ( select user_id from test_sql ) or exists ( select t2.user_id from test_sql t2 where t2.age = t1.user_id) or age < (select sum(age) from test_sql t2 where t2.dt = t1.dt ) then 2 else 1 end col from test_sql t1 order by col;
|
||||
"""
|
||||
|
||||
qt_sql14 """
|
||||
select dt,case when 'med'='med' then (
|
||||
select sum(midean) from (
|
||||
select sum(score) / count(*) as midean
|
||||
from (
|
||||
select age score,row_number() over (order by age desc) as desc_math,
|
||||
row_number() over (order by age asc) as asc_math from test_sql
|
||||
) as order_table
|
||||
where asc_math in (desc_math, desc_math + 1, desc_math - 1)) m
|
||||
)
|
||||
end 'test' from test_sql group by cube(dt) order by dt;
|
||||
"""
|
||||
|
||||
sql """drop table if exists test_sql;"""
|
||||
}
|
||||
Reference in New Issue
Block a user