[feature](nereids) support subquery in select list (#23271)

1. add scalar subquery's output to LogicalApply's output
2. for in and exists subqueries, add mark join slot into LogicalApply's output
3. forbid push down alias through join if the project list has any mark join slots.
4. move normalize aggregate rule to analysis phase
This commit is contained in:
starocean999
2023-08-31 15:51:32 +08:00
committed by GitHub
parent 62c075bf7e
commit 7379cdc995
30 changed files with 609 additions and 332 deletions

View File

@ -30,7 +30,9 @@ import org.apache.doris.nereids.rules.analysis.BindSink;
import org.apache.doris.nereids.rules.analysis.CheckAnalysis;
import org.apache.doris.nereids.rules.analysis.CheckBound;
import org.apache.doris.nereids.rules.analysis.CheckPolicy;
import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots;
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.rules.analysis.NormalizeRepeat;
import org.apache.doris.nereids.rules.analysis.ProjectToGlobalAggregate;
import org.apache.doris.nereids.rules.analysis.ProjectWithDistinctToAggregate;
@ -110,9 +112,14 @@ public class Analyzer extends AbstractBatchJobExecutor {
// LogicalProject for normalize. This rule depends on FillUpMissingSlots to fill up slots.
new NormalizeRepeat()
),
bottomUp(new SubqueryToApply()),
bottomUp(new AdjustAggregateNullableForEmptySet()),
bottomUp(new CheckAnalysis())
// run CheckAnalysis before EliminateGroupByConstant in order to report error message correctly like below
// select SUM(lo_tax) FROM lineorder group by 1;
// errCode = 2, detailMessage = GROUP BY expression must not contain aggregate functions: sum(lo_tax)
bottomUp(new CheckAnalysis()),
topDown(new EliminateGroupByConstant()),
topDown(new NormalizeAggregate()),
bottomUp(new SubqueryToApply())
);
}
}

View File

@ -25,7 +25,9 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet;
import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount;
import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite;
import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject;
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite;
import org.apache.doris.nereids.rules.expression.ExpressionNormalization;
import org.apache.doris.nereids.rules.expression.ExpressionOptimization;
@ -52,7 +54,6 @@ import org.apache.doris.nereids.rules.rewrite.EliminateAggregate;
import org.apache.doris.nereids.rules.rewrite.EliminateDedupJoinCondition;
import org.apache.doris.nereids.rules.rewrite.EliminateEmptyRelation;
import org.apache.doris.nereids.rules.rewrite.EliminateFilter;
import org.apache.doris.nereids.rules.rewrite.EliminateGroupByConstant;
import org.apache.doris.nereids.rules.rewrite.EliminateLimit;
import org.apache.doris.nereids.rules.rewrite.EliminateNotNull;
import org.apache.doris.nereids.rules.rewrite.EliminateNullAwareLeftAntiJoin;
@ -74,12 +75,12 @@ import org.apache.doris.nereids.rules.rewrite.MergeFilters;
import org.apache.doris.nereids.rules.rewrite.MergeOneRowRelationIntoUnion;
import org.apache.doris.nereids.rules.rewrite.MergeProjects;
import org.apache.doris.nereids.rules.rewrite.MergeSetOperations;
import org.apache.doris.nereids.rules.rewrite.NormalizeAggregate;
import org.apache.doris.nereids.rules.rewrite.NormalizeSort;
import org.apache.doris.nereids.rules.rewrite.PruneFileScanPartition;
import org.apache.doris.nereids.rules.rewrite.PruneOlapScanPartition;
import org.apache.doris.nereids.rules.rewrite.PruneOlapScanTablet;
import org.apache.doris.nereids.rules.rewrite.PullUpCteAnchor;
import org.apache.doris.nereids.rules.rewrite.PullUpProjectUnderApply;
import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoEsScan;
import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoJdbcScan;
import org.apache.doris.nereids.rules.rewrite.PushFilterInsideJoin;
@ -139,6 +140,10 @@ public class Rewriter extends AbstractBatchJobExecutor {
),
// subquery unnesting relies on ExpressionNormalization to extract common factor expression
topic("Subquery unnesting",
// after doing NormalizeAggregate in the analysis job,
// we need to run the following 2 rules to make AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION work
bottomUp(new PullUpProjectUnderApply()),
topDown(new PushdownFilterThroughProject()),
costBased(
custom(RuleType.AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION,
AggScalarSubQueryToWindowFunction::new)

View File

@ -15,12 +15,13 @@
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;

View File

@ -15,10 +15,12 @@
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;

View File

@ -21,9 +21,7 @@ import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Exists;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InSubquery;
@ -47,7 +45,6 @@ import com.google.common.collect.ImmutableSet;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@ -68,8 +65,8 @@ public class SubqueryToApply implements AnalysisRuleFactory {
logicalFilter().thenApply(ctx -> {
LogicalFilter<Plan> filter = ctx.root;
ImmutableList<Set> subqueryExprsList = filter.getConjuncts().stream()
.map(e -> (Set) e.collect(SubqueryExpr.class::isInstance))
ImmutableList<Set<SubqueryExpr>> subqueryExprsList = filter.getConjuncts().stream()
.map(e -> (Set<SubqueryExpr>) e.collect(SubqueryExpr.class::isInstance))
.collect(ImmutableList.toImmutableList());
if (subqueryExprsList.stream()
.flatMap(Collection::stream).noneMatch(SubqueryExpr.class::isInstance)) {
@ -104,8 +101,7 @@ public class SubqueryToApply implements AnalysisRuleFactory {
tmpPlan = applyPlan;
newConjuncts.add(conjunct);
}
Set<Expression> conjuncts = new LinkedHashSet<>();
conjuncts.addAll(newConjuncts.build());
Set<Expression> conjuncts = ImmutableSet.copyOf(newConjuncts.build());
Plan newFilter = new LogicalFilter<>(conjuncts, applyPlan);
if (conjuncts.stream().flatMap(c -> c.children().stream())
.anyMatch(MarkJoinSlotReference.class::isInstance)) {
@ -116,36 +112,44 @@ public class SubqueryToApply implements AnalysisRuleFactory {
return new LogicalFilter<>(conjuncts, applyPlan);
})
),
RuleType.PROJECT_SUBQUERY_TO_APPLY.build(
logicalProject().thenApply(ctx -> {
LogicalProject<Plan> project = ctx.root;
Set<SubqueryExpr> subqueryExprs = new LinkedHashSet<>();
project.getProjects().stream()
.filter(Alias.class::isInstance)
.map(Alias.class::cast)
.filter(alias -> alias.child() instanceof CaseWhen)
.forEach(alias -> alias.child().children().stream()
.forEach(e ->
subqueryExprs.addAll(e.collect(SubqueryExpr.class::isInstance))));
if (subqueryExprs.isEmpty()) {
return project;
}
RuleType.PROJECT_SUBQUERY_TO_APPLY.build(logicalProject().thenApply(ctx -> {
LogicalProject<Plan> project = ctx.root;
ImmutableList<Set<SubqueryExpr>> subqueryExprsList = project.getProjects().stream()
.map(e -> (Set<SubqueryExpr>) e.collect(SubqueryExpr.class::isInstance))
.collect(ImmutableList.toImmutableList());
if (subqueryExprsList.stream().flatMap(Collection::stream).count() == 0) {
return project;
}
List<NamedExpression> oldProjects = ImmutableList.copyOf(project.getProjects());
ImmutableList.Builder<NamedExpression> newProjects = new ImmutableList.Builder<>();
LogicalPlan childPlan = (LogicalPlan) project.child();
LogicalPlan applyPlan;
for (int i = 0; i < subqueryExprsList.size(); ++i) {
Set<SubqueryExpr> subqueryExprs = subqueryExprsList.get(i);
if (subqueryExprs.isEmpty()) {
newProjects.add(oldProjects.get(i));
continue;
}
SubqueryContext context = new SubqueryContext(subqueryExprs);
return new LogicalProject(project.getProjects().stream()
.map(p -> p.withChildren(
new ReplaceSubquery(ctx.statementContext, true)
.replace(p, context)))
.collect(ImmutableList.toImmutableList()),
subqueryToApply(
subqueryExprs.stream().collect(ImmutableList.toImmutableList()),
(LogicalPlan) project.child(),
context.getSubqueryToMarkJoinSlot(),
ctx.cascadesContext,
Optional.empty(), true
));
})
)
// first step: replace the subquery in LogicalProject's project list
// second step: replace the subquery with LogicalApply
ReplaceSubquery replaceSubquery =
new ReplaceSubquery(ctx.statementContext, true);
SubqueryContext context = new SubqueryContext(subqueryExprs);
Expression newProject =
replaceSubquery.replace(oldProjects.get(i), context);
applyPlan = subqueryToApply(
subqueryExprs.stream().collect(ImmutableList.toImmutableList()),
childPlan, context.getSubqueryToMarkJoinSlot(),
ctx.cascadesContext,
Optional.of(newProject), true);
childPlan = applyPlan;
newProjects.add((NamedExpression) newProject);
}
return project.withProjectsAndChild(newProjects.build(), childPlan);
}))
);
}
@ -249,28 +253,30 @@ public class SubqueryToApply implements AnalysisRuleFactory {
// The result set when NULL is specified in the subquery and still evaluates to TRUE by using EXISTS
// When the number of rows returned is empty, agg will return null, so if there is more agg,
// it will always consider the returned result to be true
boolean needCreateMarkJoinSlot = isMarkJoin || isProject;
MarkJoinSlotReference markJoinSlotReference = null;
if (exists.getQueryPlan().anyMatch(Aggregate.class::isInstance) && isMarkJoin) {
if (exists.getQueryPlan().anyMatch(Aggregate.class::isInstance) && needCreateMarkJoinSlot) {
markJoinSlotReference =
new MarkJoinSlotReference(statementContext.generateColumnName(), true);
} else if (isMarkJoin) {
} else if (needCreateMarkJoinSlot) {
markJoinSlotReference =
new MarkJoinSlotReference(statementContext.generateColumnName());
}
if (isMarkJoin) {
if (needCreateMarkJoinSlot) {
context.setSubqueryToMarkJoinSlot(exists, Optional.of(markJoinSlotReference));
}
return isMarkJoin ? markJoinSlotReference : BooleanLiteral.TRUE;
return needCreateMarkJoinSlot ? markJoinSlotReference : BooleanLiteral.TRUE;
}
@Override
public Expression visitInSubquery(InSubquery in, SubqueryContext context) {
MarkJoinSlotReference markJoinSlotReference =
new MarkJoinSlotReference(statementContext.generateColumnName());
if (isMarkJoin) {
boolean needCreateMarkJoinSlot = isMarkJoin || isProject;
if (needCreateMarkJoinSlot) {
context.setSubqueryToMarkJoinSlot(in, Optional.of(markJoinSlotReference));
}
return isMarkJoin ? markJoinSlotReference : BooleanLiteral.TRUE;
return needCreateMarkJoinSlot ? markJoinSlotReference : BooleanLiteral.TRUE;
}
@Override

View File

@ -29,8 +29,8 @@ import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.properties.RequireProperties;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE;
import org.apache.doris.nereids.rules.rewrite.NormalizeAggregate;
import org.apache.doris.nereids.trees.expressions.AggregateExpression;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Cast;

View File

@ -22,6 +22,7 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
@ -45,7 +46,9 @@ public class PushdownAliasThroughJoin extends OneRewriteRuleFactory {
public Rule build() {
return logicalProject(logicalJoin())
.when(project -> project.getProjects().stream().allMatch(expr ->
(expr instanceof Slot) || (expr instanceof Alias && ((Alias) expr).child() instanceof Slot)))
(expr instanceof Slot && !(expr instanceof MarkJoinSlotReference))
|| (expr instanceof Alias && ((Alias) expr).child() instanceof Slot
&& !(((Alias) expr).child() instanceof MarkJoinSlotReference))))
.when(project -> project.getProjects().stream().anyMatch(expr -> expr instanceof Alias))
.then(project -> {
LogicalJoin<? extends Plan, ? extends Plan> join = project.child();

View File

@ -94,7 +94,7 @@ public class CaseWhen extends Expression {
StringBuilder output = new StringBuilder("CASE");
for (Expression child : children()) {
if (child instanceof WhenClause) {
output.append(child);
output.append(child.toString());
} else {
output.append(" ELSE ").append(child.toString());
}

View File

@ -58,4 +58,9 @@ public class DoubleLiteral extends Literal {
nf.setGroupingUsed(false);
return nf.format(value);
}
@Override
public String getStringValue() {
// Delegates to toString(), which (per the lines above) formats the double
// through a NumberFormat with grouping disabled, so the string value is
// the plain numeric form without thousands separators.
return toString();
}
}

View File

@ -22,6 +22,8 @@ import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.FloatType;
import java.text.NumberFormat;
/**
* float type literal
*/
@ -48,4 +50,11 @@ public class FloatLiteral extends Literal {
public LiteralExpr toLegacyLiteral() {
return new org.apache.doris.analysis.FloatLiteral((double) value, Type.FLOAT);
}
@Override
public String getStringValue() {
// Format the float's value without grouping (no thousands separators),
// mirroring the formatting used by DoubleLiteral.
// NOTE(review): NumberFormat.getInstance() is locale-sensitive (decimal
// separator may differ by default locale) — confirm this is intended,
// or whether Locale.ROOT should be pinned for a stable literal string.
NumberFormat nf = NumberFormat.getInstance();
nf.setGroupingUsed(false);
return nf.format(value);
}
}

View File

@ -140,7 +140,7 @@ public class AnalyzeCTETest extends TestWithFeService implements MemoPatternMatc
logicalFilter(
logicalProject(
logicalJoin(
logicalAggregate(),
logicalProject(logicalAggregate()),
logicalProject()
)
)

View File

@ -156,18 +156,20 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP
.matchesNotCheck(
logicalApply(
any(),
logicalAggregate(
logicalFilter()
).when(FieldChecker.check("outputExpressions", ImmutableList.of(
new Alias(new ExprId(7),
(new Sum(
new SlotReference(new ExprId(4), "k3",
BigIntType.INSTANCE, true,
ImmutableList.of(
"default_cluster:test",
"t7")))).withAlwaysNullable(
true),
"sum(k3)"))))
logicalProject(
logicalAggregate(
logicalProject()
).when(FieldChecker.check("outputExpressions", ImmutableList.of(
new Alias(new ExprId(7),
(new Sum(
new SlotReference(new ExprId(4), "k3",
BigIntType.INSTANCE, true,
ImmutableList.of(
"default_cluster:test",
"t7")))).withAlwaysNullable(
true),
"sum(k3)"))))
)
).when(FieldChecker.check("correlationSlot", ImmutableList.of(
new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test", "t6"))
@ -383,28 +385,32 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP
logicalProject(
logicalApply(
any(),
logicalAggregate(
logicalSubQueryAlias(
logicalProject(
logicalAggregate(
logicalProject(
logicalFilter()
).when(p -> p.getProjects().equals(ImmutableList.of(
new Alias(new ExprId(7), new SlotReference(new ExprId(5), "v1", BigIntType.INSTANCE,
true,
ImmutableList.of("default_cluster:test", "t7")), "aa")
)))
)
.when(a -> a.getAlias().equals("t2"))
.when(a -> a.getOutput().equals(ImmutableList.of(
new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE,
true, ImmutableList.of("t2"))
logicalSubQueryAlias(
logicalProject(
logicalFilter()
).when(p -> p.getProjects().equals(ImmutableList.of(
new Alias(new ExprId(7), new SlotReference(new ExprId(5), "v1", BigIntType.INSTANCE,
true,
ImmutableList.of("default_cluster:test", "t7")), "aa")
)))
)
.when(a -> a.getAlias().equals("t2"))
.when(a -> a.getOutput().equals(ImmutableList.of(
new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE,
true, ImmutableList.of("t2"))
)))
)
).when(agg -> agg.getOutputExpressions().equals(ImmutableList.of(
new Alias(new ExprId(8),
(new Max(new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE,
true,
ImmutableList.of("t2")))).withAlwaysNullable(true), "max(aa)")
)))
).when(agg -> agg.getOutputExpressions().equals(ImmutableList.of(
new Alias(new ExprId(8),
(new Max(new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE,
true,
ImmutableList.of("t2")))).withAlwaysNullable(true), "max(aa)")
)))
.when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of()))
.when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of()))
)
)
.when(apply -> apply.getCorrelationSlot().equals(ImmutableList.of(
new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true,

View File

@ -90,10 +90,11 @@ class BindSlotReferenceTest {
join
);
PlanChecker checker = PlanChecker.from(MemoTestUtils.createConnectContext()).analyze(aggregate);
LogicalAggregate plan = (LogicalAggregate) checker.getCascadesContext().getMemo().copyOut();
LogicalAggregate plan = (LogicalAggregate) ((LogicalProject) checker.getCascadesContext()
.getMemo().copyOut()).child();
SlotReference groupByKey = (SlotReference) plan.getGroupByExpressions().get(0);
SlotReference t1id = (SlotReference) ((LogicalJoin) plan.child()).left().getOutput().get(0);
SlotReference t2id = (SlotReference) ((LogicalJoin) plan.child()).right().getOutput().get(0);
SlotReference t1id = (SlotReference) ((LogicalJoin) plan.child().child(0)).left().getOutput().get(0);
SlotReference t2id = (SlotReference) ((LogicalJoin) plan.child().child(0)).right().getOutput().get(0);
Assertions.assertEquals(groupByKey.getExprId(), t1id.getExprId());
Assertions.assertNotEquals(t1id.getExprId(), t2id.getExprId());
}

View File

@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
@ -23,7 +23,6 @@ import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.PartitionInfo;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Slot;

View File

@ -35,6 +35,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.util.FieldChecker;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
@ -45,8 +46,6 @@ import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Test;
import java.util.stream.Collectors;
public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements MemoPatternMatchSupported {
@Override
@ -86,35 +85,35 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
);
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalFilter(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1)))));
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1))))));
sql = "SELECT a1 as value FROM t1 GROUP BY a1 HAVING a1 > 0";
a1 = new SlotReference(
new ExprId(1), "a1", TinyIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1")
);
Alias value = new Alias(new ExprId(3), a1, "value");
SlotReference value = new SlotReference(new ExprId(3), "value", TinyIntType.INSTANCE, true,
ImmutableList.of());
PlanChecker.from(connectContext).analyze(sql)
.applyBottomUp(new ExpressionRewrite(FunctionBinder.INSTANCE))
.matches(
logicalProject(
logicalFilter(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0)))))));
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value))))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0)))))));
sql = "SELECT a1 as value FROM t1 GROUP BY a1 HAVING value > 0";
PlanChecker.from(connectContext).analyze(sql)
.applyBottomUp(new ExpressionRewrite(FunctionBinder.INSTANCE))
.matches(
logicalFilter(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value)))
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(value))))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), new TinyIntLiteral((byte) 0))))));
sql = "SELECT SUM(a2) FROM t1 GROUP BY a1 HAVING a1 > 0";
@ -130,13 +129,14 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
PlanChecker.from(connectContext).analyze(sql)
.applyBottomUp(new ExpressionRewrite(FunctionBinder.INSTANCE))
.matches(
logicalProject(
logicalFilter(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(sumA2, a1)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(a1, new TinyIntLiteral((byte) 0)))))
).when(FieldChecker.check("projects", Lists.newArrayList(sumA2.toSlot()))));
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(a1, new TinyIntLiteral((byte) 0)))))
).when(FieldChecker.check("projects", Lists.newArrayList(sumA2.toSlot()))));
}
@Test
@ -153,24 +153,28 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
Alias sumA2 = new Alias(new ExprId(3), new Sum(a2), "sum(a2)");
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalProject(
logicalFilter(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));
sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 HAVING SUM(a2) > 0";
sumA2 = new Alias(new ExprId(3), new Sum(a2), "SUM(a2)");
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalProject(
logicalFilter(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))));
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(
logicalOlapScan()
)
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))));
sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 HAVING SUM(a2) > 0";
a1 = new SlotReference(
@ -184,20 +188,24 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
Alias value = new Alias(new ExprId(3), new Sum(a2), "value");
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalProject(
logicalFilter(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L)))))));
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(
logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L)))))));
sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 HAVING value > 0";
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalFilter(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
logicalProject(
logicalAggregate(
logicalProject(
logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L))))));
sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 HAVING MIN(pk) > 0";
@ -217,49 +225,53 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
Alias minPK = new Alias(new ExprId(4), new Min(pk), "min(pk)");
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalProject(
logicalFilter(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(minPK.toSlot(), Literal.of((byte) 0)))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK))))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(minPK.toSlot(), Literal.of((byte) 0)))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));
sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 HAVING SUM(a1 + a2) > 0";
Alias sumA1A2 = new Alias(new ExprId(3), new Sum(new Add(a1, a2)), "SUM((a1 + a2))");
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalProject(
logicalFilter(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A2.toSlot(), Literal.of(0L)))))));
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2))))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A2.toSlot(), Literal.of(0L)))))));
sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 HAVING SUM(a1 + a2 + 3) > 0";
Alias sumA1A23 = new Alias(new ExprId(4), new Sum(new Add(new Add(a1, a2), new TinyIntLiteral((byte) 3))),
"sum(((a1 + a2) + 3))");
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalProject(
logicalFilter(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A23.toSlot(), Literal.of(0L)))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot()))));
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23))))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A23.toSlot(), Literal.of(0L)))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot()))));
sql = "SELECT a1 FROM t1 GROUP BY a1 HAVING COUNT(*) > 0";
Alias countStar = new Alias(new ExprId(3), new Count(), "count(*)");
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalProject(
logicalFilter(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(countStar.toSlot(), Literal.of(0L)))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar))))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(countStar.toSlot(), Literal.of(0L)))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));
}
@Test
@ -281,19 +293,21 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
Alias sumB1 = new Alias(new ExprId(7), new Sum(b1), "sum(b1)");
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalProject(
logicalFilter(
logicalAggregate(
logicalProject(
logicalFilter(
logicalJoin(
logicalOlapScan(),
logicalOlapScan()
)
)
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, sumB1)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(new Cast(a1, BigIntType.INSTANCE),
sumB1.toSlot()))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));
logicalProject(
logicalAggregate(
logicalProject(
logicalFilter(
logicalJoin(
logicalOlapScan(),
logicalOlapScan()
)
))
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, sumB1)))
)).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(new Cast(a1, BigIntType.INSTANCE),
sumB1.toSlot()))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));
}
@Test
@ -331,6 +345,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
new ExprId(0), "pk", TinyIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1")
);
SlotReference pk1 = new SlotReference(
new ExprId(6), "(pk + 1)", IntegerType.INSTANCE, true,
ImmutableList.of()
);
SlotReference a1 = new SlotReference(
new ExprId(1), "a1", TinyIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1")
@ -339,40 +357,42 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
new ExprId(2), "a2", TinyIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1")
);
Alias pk1 = new Alias(new ExprId(6), new Add(pk, Literal.of((byte) 1)), "(pk + 1)");
Alias pk11 = new Alias(new ExprId(7), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)");
Alias pk2 = new Alias(new ExprId(8), new Add(pk, Literal.of((byte) 2)), "(pk + 2)");
Alias sumA1 = new Alias(new ExprId(9), new Sum(a1), "SUM(a1)");
Alias countA11 = new Alias(new ExprId(10), new Add(new Count(a1), Literal.of((byte) 1)), "(COUNT(a1) + 1)");
Alias countA1 = new Alias(new ExprId(13), new Count(a1), "count(a1)");
Alias countA11 = new Alias(new ExprId(10), new Add(countA1.toSlot(), Literal.of((byte) 1)), "(COUNT(a1) + 1)");
Alias sumA1A2 = new Alias(new ExprId(11), new Sum(new Add(a1, a2)), "SUM((a1 + a2))");
Alias v1 = new Alias(new ExprId(12), new Count(a2), "v1");
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalProject(
logicalFilter(
logicalAggregate(
logicalProject(
logicalFilter(
logicalJoin(
logicalOlapScan(),
logicalOlapScan()
)
logicalProject(
logicalAggregate(
logicalProject(
logicalFilter(
logicalJoin(
logicalOlapScan(),
logicalOlapScan()
)
))
).when(FieldChecker.check("outputExpressions",
Lists.newArrayList(pk, pk1, sumA1, countA1, sumA1A2, v1))))
).when(FieldChecker.check("conjuncts",
ImmutableSet.of(
new GreaterThan(pk.toSlot(), Literal.of((byte) 0)),
new GreaterThan(countA11.toSlot(), Literal.of(0L)),
new GreaterThan(new Add(sumA1A2.toSlot(), Literal.of((byte) 1)), Literal.of(0L)),
new GreaterThan(new Add(v1.toSlot(), Literal.of((byte) 1)), Literal.of(0L)),
new GreaterThan(v1.toSlot(), Literal.of(0L))
))
)
).when(FieldChecker.check("outputExpressions",
Lists.newArrayList(pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1, pk)))
).when(FieldChecker.check("conjuncts",
ImmutableSet.of(
new GreaterThan(pk.toSlot(), Literal.of((byte) 0)),
new GreaterThan(countA11.toSlot(), Literal.of(0L)),
new GreaterThan(new Add(sumA1A2.toSlot(), Literal.of((byte) 1)), Literal.of(0L)),
new GreaterThan(new Add(v1.toSlot(), Literal.of((byte) 1)), Literal.of(0L)),
new GreaterThan(v1.toSlot(), Literal.of(0L))
))
)
).when(FieldChecker.check(
"projects", Lists.newArrayList(
pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1).stream()
.map(Alias::toSlot).collect(Collectors.toList()))
));
).when(FieldChecker.check(
"projects", Lists.newArrayList(
pk1, pk11.toSlot(), pk2.toSlot(), sumA1.toSlot(), countA11.toSlot(), sumA1A2.toSlot(), v1.toSlot())
)
));
}
@Test
@ -391,9 +411,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
.matches(
logicalProject(
logicalSort(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));
@ -402,9 +423,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalSort(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true)))));
sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 ORDER BY SUM(a2)";
@ -420,9 +442,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalSort(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))))
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA2.toSlot(), true, true)))));
sql = "SELECT a1, SUM(a2) FROM t1 GROUP BY a1 ORDER BY MIN(pk)";
@ -444,9 +467,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
.matches(
logicalProject(
logicalSort(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK)))
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK))))
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(minPK.toSlot(), true, true))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));
@ -455,9 +479,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalSort(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2)))
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2))))
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA1A2.toSlot(), true, true)))));
sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 ORDER BY SUM(a1 + a2 + 3)";
@ -467,9 +492,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
.matches(
logicalProject(
logicalSort(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23)))
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23))))
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(sumA1A23.toSlot(), true, true))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot()))));
@ -479,9 +505,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
.matches(
logicalProject(
logicalSort(
logicalAggregate(
logicalOlapScan()
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar)))
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar))))
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(countStar.toSlot(), true, true))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));
}
@ -495,6 +522,10 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
new ExprId(0), "pk", TinyIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1")
);
SlotReference pk1 = new SlotReference(
new ExprId(6), "(pk + 1)", IntegerType.INSTANCE, true,
ImmutableList.of()
);
SlotReference a1 = new SlotReference(
new ExprId(1), "a1", TinyIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1")
@ -503,40 +534,41 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo
new ExprId(2), "a2", TinyIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test_resolve_aggregate_functions", "t1")
);
Alias pk1 = new Alias(new ExprId(6), new Add(pk, Literal.of((byte) 1)), "(pk + 1)");
Alias pk11 = new Alias(new ExprId(7), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)");
Alias pk2 = new Alias(new ExprId(8), new Add(pk, Literal.of((byte) 2)), "(pk + 2)");
Alias sumA1 = new Alias(new ExprId(9), new Sum(a1), "SUM(a1)");
Alias countA1 = new Alias(new ExprId(13), new Count(a1), "count(a1)");
Alias countA11 = new Alias(new ExprId(10), new Add(new Count(a1), Literal.of((byte) 1)), "(COUNT(a1) + 1)");
Alias sumA1A2 = new Alias(new ExprId(11), new Sum(new Add(a1, a2)), "SUM((a1 + a2))");
Alias v1 = new Alias(new ExprId(12), new Count(a2), "v1");
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalProject(
logicalSort(
logicalAggregate(
logicalFilter(
logicalJoin(
logicalOlapScan(),
logicalOlapScan()
)
)
).when(FieldChecker.check("outputExpressions",
Lists.newArrayList(pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1, pk)))
).when(FieldChecker.check("orderKeys",
ImmutableList.of(
new OrderKey(pk, true, true),
new OrderKey(countA11.toSlot(), true, true),
new OrderKey(new Add(sumA1A2.toSlot(), new TinyIntLiteral((byte) 1)), true, true),
new OrderKey(new Add(v1.toSlot(), new TinyIntLiteral((byte) 1)), true, true),
new OrderKey(v1.toSlot(), true, true)
)
))
).when(FieldChecker.check(
"projects", Lists.newArrayList(
pk1, pk11, pk2, sumA1, countA11, sumA1A2, v1).stream()
.map(Alias::toSlot).collect(Collectors.toList()))
));
.matches(logicalProject(logicalSort(logicalProject(logicalAggregate(logicalProject(
logicalFilter(logicalJoin(logicalOlapScan(), logicalOlapScan())))).when(
FieldChecker.check("outputExpressions", Lists.newArrayList(pk, pk1,
sumA1, countA1, sumA1A2, v1))))).when(FieldChecker.check(
"orderKeys",
ImmutableList.of(new OrderKey(pk, true, true),
new OrderKey(
countA11.toSlot(), true, true),
new OrderKey(
new Add(sumA1A2.toSlot(),
new TinyIntLiteral(
(byte) 1)),
true, true),
new OrderKey(
new Add(v1.toSlot(),
new TinyIntLiteral(
(byte) 1)),
true, true),
new OrderKey(v1.toSlot(), true, true)))))
.when(FieldChecker.check("projects",
Lists.newArrayList(pk1,
pk11.toSlot(),
pk2.toSlot(),
sumA1.toSlot(),
countA11.toSlot(),
sumA1A2.toSlot(),
v1.toSlot()))));
}
@Test

View File

@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.rules.implementation.AggregateStrategies;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.AggregateExpression;

View File

@ -299,15 +299,16 @@ public class ColumnPruningTest extends TestWithFeService implements MemoPatternM
.matches(
logicalProject(
logicalSubQueryAlias(
logicalAggregate(
logicalProject(
logicalOlapScan()
).when(p -> getOutputQualifiedNames(p).equals(
ImmutableList.of("default_cluster:test.student.id")
))
).when(agg -> getOutputQualifiedNames(agg.getOutputs()).equals(
ImmutableList.of("default_cluster:test.student.id")
))
logicalAggregate(
logicalProject(
logicalOlapScan()
).when(p -> getOutputQualifiedNames(p).equals(
ImmutableList.of("default_cluster:test.student.id")
))
).when(agg -> getOutputQualifiedNames(agg.getOutputs()).equals(
ImmutableList.of("default_cluster:test.student.id")
)))
)
)
);

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;

View File

@ -18,6 +18,10 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
@ -30,6 +34,8 @@ import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;
import java.util.List;
class PushdownAliasThroughJoinTest implements MemoPatternMatchSupported {
private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
@ -99,4 +105,21 @@ class PushdownAliasThroughJoinTest implements MemoPatternMatchSupported {
&& project.getProjects().get(1).toSql().equals("2name"))
);
}
// Verifies rule PushdownAliasThroughJoin does NOT push projections below the join
// when the project list contains mark join slots: both the bare mark slot and the
// alias wrapping a mark slot must remain in the project node above the join.
@Test
void testNoPushdownMarkJoin() {
    MarkJoinSlotReference bareMarkSlot =
            new MarkJoinSlotReference(new ExprId(101), "markSlot1", false);
    Alias aliasedMarkSlot = new Alias(
            new MarkJoinSlotReference(new ExprId(102), "markSlot2", false), "markSlot2");
    List<NamedExpression> markJoinProjects = ImmutableList.of(bareMarkSlot, aliasedMarkSlot);
    LogicalPlan joinPlan = new LogicalPlanBuilder(scan1)
            .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
            .projectExprs(markJoinProjects)
            .build();
    PlanChecker.from(MemoTestUtils.createConnectContext(), joinPlan)
            .applyTopDown(new PushdownAliasThroughJoin())
            .matches(logicalProject(logicalJoin(logicalOlapScan(), logicalOlapScan()))
                    .when(project -> {
                        // Both expressions must survive above the join, unchanged.
                        String first = project.getProjects().get(0).toSql();
                        String second = project.getProjects().get(1).toSql();
                        return first.equals("markSlot1")
                                && second.equals("markSlot2 AS `markSlot2`");
                    }));
}
}

View File

@ -135,20 +135,22 @@ public class PushdownExpressionsInHashConditionTest extends TestWithFeService im
.applyTopDown(new FindHashConditionForJoin())
.applyTopDown(new PushdownExpressionsInHashCondition())
.matches(
logicalProject(
logicalJoin(
logicalProject(
logicalOlapScan()
),
logicalProject(
logicalSubQueryAlias(
logicalAggregate(
logicalOlapScan()
)
logicalProject(
logicalJoin(
logicalProject(
logicalOlapScan()
),
logicalProject(
logicalSubQueryAlias(
logicalProject(
logicalAggregate(
logicalProject(
logicalOlapScan()
)))
)
)
)
)
)
)
);
}
@ -168,8 +170,12 @@ public class PushdownExpressionsInHashConditionTest extends TestWithFeService im
logicalProject(
logicalSubQueryAlias(
logicalSort(
logicalAggregate(
logicalOlapScan()
logicalProject(
logicalAggregate(
logicalProject(
logicalOlapScan()
)
)
)
)
)

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.rewrite.mv;
import org.apache.doris.common.FeConstants;
import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject;
import org.apache.doris.nereids.rules.rewrite.MergeProjects;
import org.apache.doris.nereids.rules.rewrite.PushdownFilterThroughProject;
import org.apache.doris.nereids.trees.plans.PreAggStatus;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
@ -188,7 +189,8 @@ class SelectRollupIndexTest extends BaseMaterializedIndexSelectTest implements M
PlanChecker.from(connectContext)
.analyze(sql)
.applyBottomUp(new LogicalSubQueryAliasToLogicalProject())
.applyTopDown(new MergeProjects())
.applyTopDown(new PushdownFilterThroughProject())
.applyBottomUp(new MergeProjects())
.applyTopDown(new SelectMaterializedIndexWithAggregate())
.applyTopDown(new SelectMaterializedIndexWithoutAggregate())
.matches(logicalOlapScan().when(scan -> {

View File

@ -0,0 +1,50 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !sql1 --
3
-- !sql2 --
3
-- !sql3 --
3
-- !sql4 --
false
-- !sql5 --
false
-- !sql6 --
true
-- !sql7 --
2
-- !sql8 --
4
4
-- !sql9 --
4
4
-- !sql10 --
false
true
-- !sql11 --
false
true
-- !sql12 --
true
true
-- !sql13 --
2
2
-- !sql14 --
\N 2.0
2020-09-09 2.0

View File

@ -23,7 +23,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
----------------PhysicalProject
------------------PhysicalOlapScan[customer]
--------------PhysicalDistribute
----------------hashJoin[INNER_JOIN](ctr1.ctr_store_sk = ctr2.ctr_store_sk)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE))
----------------hashJoin[INNER_JOIN](ctr1.ctr_store_sk = ctr2.ctr_store_sk)(cast(ctr_total_return as DOUBLE) > cast((avg(cast(ctr_total_return as DECIMALV3(38, 4))) * 1.2) as DOUBLE))
------------------PhysicalProject
--------------------hashJoin[INNER_JOIN](store.s_store_sk = ctr1.ctr_store_sk)
----------------------PhysicalDistribute
@ -32,11 +32,10 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
------------------------PhysicalProject
--------------------------filter((cast(s_state as VARCHAR(*)) = 'SD'))
----------------------------PhysicalOlapScan[store]
------------------PhysicalProject
--------------------hashAgg[GLOBAL]
----------------------PhysicalDistribute
------------------------hashAgg[LOCAL]
--------------------------PhysicalDistribute
----------------------------PhysicalProject
------------------------------PhysicalCteConsumer ( cteId=CTEId#0 )
------------------hashAgg[GLOBAL]
--------------------PhysicalDistribute
----------------------hashAgg[LOCAL]
------------------------PhysicalDistribute
--------------------------PhysicalProject
----------------------------PhysicalCteConsumer ( cteId=CTEId#0 )

View File

@ -24,7 +24,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
------PhysicalDistribute
--------PhysicalTopN
----------PhysicalProject
------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE))
------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(cast(ctr_total_return as DECIMALV3(38, 4))) * 1.2) as DOUBLE))
--------------hashJoin[INNER_JOIN](ctr1.ctr_customer_sk = customer.c_customer_sk)
----------------PhysicalDistribute
------------------PhysicalCteConsumer ( cteId=CTEId#0 )
@ -38,11 +38,10 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
--------------------------filter((cast(ca_state as VARCHAR(*)) = 'IN'))
----------------------------PhysicalOlapScan[customer_address]
--------------PhysicalDistribute
----------------PhysicalProject
------------------hashAgg[GLOBAL]
--------------------PhysicalDistribute
----------------------hashAgg[LOCAL]
------------------------PhysicalDistribute
--------------------------PhysicalProject
----------------------------PhysicalCteConsumer ( cteId=CTEId#0 )
----------------hashAgg[GLOBAL]
------------------PhysicalDistribute
--------------------hashAgg[LOCAL]
----------------------PhysicalDistribute
------------------------PhysicalProject
--------------------------PhysicalCteConsumer ( cteId=CTEId#0 )

View File

@ -14,30 +14,32 @@ PhysicalResultSink
----------------------PhysicalWindow
------------------------PhysicalQuickSort
--------------------------PhysicalDistribute
----------------------------hashAgg[GLOBAL]
------------------------------PhysicalDistribute
--------------------------------hashAgg[LOCAL]
----------------------------------PhysicalProject
------------------------------------hashJoin[INNER_JOIN](store_sales.ss_sold_date_sk = date_dim.d_date_sk)
--------------------------------------PhysicalProject
----------------------------------------PhysicalOlapScan[store_sales]
--------------------------------------PhysicalDistribute
----------------------------PhysicalProject
------------------------------hashAgg[GLOBAL]
--------------------------------PhysicalDistribute
----------------------------------hashAgg[LOCAL]
------------------------------------PhysicalProject
--------------------------------------hashJoin[INNER_JOIN](store_sales.ss_sold_date_sk = date_dim.d_date_sk)
----------------------------------------PhysicalProject
------------------------------------------filter((date_dim.d_month_seq <= 1227)(date_dim.d_month_seq >= 1216))
--------------------------------------------PhysicalOlapScan[date_dim]
------------------------------------------PhysicalOlapScan[store_sales]
----------------------------------------PhysicalDistribute
------------------------------------------PhysicalProject
--------------------------------------------filter((date_dim.d_month_seq <= 1227)(date_dim.d_month_seq >= 1216))
----------------------------------------------PhysicalOlapScan[date_dim]
--------------------PhysicalProject
----------------------PhysicalWindow
------------------------PhysicalQuickSort
--------------------------PhysicalDistribute
----------------------------hashAgg[GLOBAL]
------------------------------PhysicalDistribute
--------------------------------hashAgg[LOCAL]
----------------------------------PhysicalProject
------------------------------------hashJoin[INNER_JOIN](web_sales.ws_sold_date_sk = date_dim.d_date_sk)
--------------------------------------PhysicalProject
----------------------------------------PhysicalOlapScan[web_sales]
--------------------------------------PhysicalDistribute
----------------------------PhysicalProject
------------------------------hashAgg[GLOBAL]
--------------------------------PhysicalDistribute
----------------------------------hashAgg[LOCAL]
------------------------------------PhysicalProject
--------------------------------------hashJoin[INNER_JOIN](web_sales.ws_sold_date_sk = date_dim.d_date_sk)
----------------------------------------PhysicalProject
------------------------------------------filter((date_dim.d_month_seq >= 1216)(date_dim.d_month_seq <= 1227))
--------------------------------------------PhysicalOlapScan[date_dim]
------------------------------------------PhysicalOlapScan[web_sales]
----------------------------------------PhysicalDistribute
------------------------------------------PhysicalProject
--------------------------------------------filter((date_dim.d_month_seq >= 1216)(date_dim.d_month_seq <= 1227))
----------------------------------------------PhysicalOlapScan[date_dim]

View File

@ -24,7 +24,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
------PhysicalDistribute
--------PhysicalTopN
----------PhysicalProject
------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE))
------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(cast(ctr_total_return as DECIMALV3(38, 4))) * 1.2) as DOUBLE))
--------------hashJoin[INNER_JOIN](ctr1.ctr_customer_sk = customer.c_customer_sk)
----------------PhysicalDistribute
------------------PhysicalCteConsumer ( cteId=CTEId#0 )
@ -38,11 +38,10 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
--------------------------filter((cast(ca_state as VARCHAR(*)) = 'CA'))
----------------------------PhysicalOlapScan[customer_address]
--------------PhysicalDistribute
----------------PhysicalProject
------------------hashAgg[GLOBAL]
--------------------PhysicalDistribute
----------------------hashAgg[LOCAL]
------------------------PhysicalDistribute
--------------------------PhysicalProject
----------------------------PhysicalCteConsumer ( cteId=CTEId#0 )
----------------hashAgg[GLOBAL]
------------------PhysicalDistribute
--------------------hashAgg[LOCAL]
----------------------PhysicalDistribute
------------------------PhysicalProject
--------------------------PhysicalCteConsumer ( cteId=CTEId#0 )

View File

@ -9,13 +9,12 @@ PhysicalResultSink
------------PhysicalDistribute
--------------PhysicalProject
----------------hashJoin[INNER_JOIN](lineitem.l_partkey = partsupp.ps_partkey)(lineitem.l_suppkey = partsupp.ps_suppkey)(cast(ps_availqty as DECIMALV3(38, 3)) > (0.5 * sum(l_quantity)))
------------------PhysicalProject
--------------------hashAgg[GLOBAL]
----------------------PhysicalDistribute
------------------------hashAgg[LOCAL]
--------------------------PhysicalProject
----------------------------filter((lineitem.l_shipdate < 1995-01-01)(lineitem.l_shipdate >= 1994-01-01))
------------------------------PhysicalOlapScan[lineitem]
------------------hashAgg[GLOBAL]
--------------------PhysicalDistribute
----------------------hashAgg[LOCAL]
------------------------PhysicalProject
--------------------------filter((lineitem.l_shipdate < 1995-01-01)(lineitem.l_shipdate >= 1994-01-01))
----------------------------PhysicalOlapScan[lineitem]
------------------PhysicalDistribute
--------------------hashJoin[LEFT_SEMI_JOIN](partsupp.ps_partkey = part.p_partkey)
----------------------PhysicalProject

View File

@ -9,13 +9,12 @@ PhysicalResultSink
------------PhysicalDistribute
--------------PhysicalProject
----------------hashJoin[INNER_JOIN](lineitem.l_partkey = partsupp.ps_partkey)(lineitem.l_suppkey = partsupp.ps_suppkey)(cast(ps_availqty as DECIMALV3(38, 3)) > (0.5 * sum(l_quantity)))
------------------PhysicalProject
--------------------hashAgg[GLOBAL]
----------------------PhysicalDistribute
------------------------hashAgg[LOCAL]
--------------------------PhysicalProject
----------------------------filter((lineitem.l_shipdate < 1995-01-01)(lineitem.l_shipdate >= 1994-01-01))
------------------------------PhysicalOlapScan[lineitem]
------------------hashAgg[GLOBAL]
--------------------PhysicalDistribute
----------------------hashAgg[LOCAL]
------------------------PhysicalProject
--------------------------filter((lineitem.l_shipdate < 1995-01-01)(lineitem.l_shipdate >= 1994-01-01))
----------------------------PhysicalOlapScan[lineitem]
------------------PhysicalDistribute
--------------------hashJoin[LEFT_SEMI_JOIN](partsupp.ps_partkey = part.p_partkey)
----------------------PhysicalProject

View File

@ -0,0 +1,120 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// Regression test for scalar / IN / EXISTS subqueries appearing in the SELECT
// list (Nereids planner). Covers the single-row case, the multi-row case (where
// an uncorrelated scalar subquery must raise a cardinality error), correlated
// variants, and a CASE WHEN mixing several subquery forms.
suite("test_subquery_in_project") {
// Force the Nereids planner; fail instead of silently falling back so the new
// code path is actually exercised.
sql "SET enable_nereids_planner=true"
sql "SET enable_fallback_to_original_planner=false"
sql """drop table if exists test_sql;"""
sql """
CREATE TABLE `test_sql` (
`user_id` varchar(10) NULL,
`dt` date NULL,
`city` varchar(20) NULL,
`age` int(11) NULL
) ENGINE=OLAP
UNIQUE KEY(`user_id`)
COMMENT 'test'
DISTRIBUTED BY HASH(`user_id`) BUCKETS 1
PROPERTIES (
"replication_allocation" = "tag.location.default: 1",
"is_being_synced" = "false",
"storage_format" = "V2",
"light_schema_change" = "true",
"disable_auto_compaction" = "false",
"enable_single_replica_compaction" = "false"
);
"""
// One row only: every scalar subquery below legally returns a single value.
sql """ insert into test_sql values (1,'2020-09-09',2,3);"""
// Uncorrelated scalar subquery in the select list.
qt_sql1 """
select (select age from test_sql) col from test_sql order by col;
"""
// Uncorrelated scalar aggregate subquery.
qt_sql2 """
select (select sum(age) from test_sql) col from test_sql order by col;
"""
// Correlated scalar aggregate subquery (correlated on dt).
qt_sql3 """
select (select sum(age) from test_sql t2 where t2.dt = t1.dt ) col from test_sql t1 order by col;
"""
// Uncorrelated IN subquery projected as a boolean column (mark join).
qt_sql4 """
select age in (select user_id from test_sql) col from test_sql order by col;
"""
// Correlated IN subquery projected as a boolean column.
qt_sql5 """
select age in (select user_id from test_sql t2 where t2.user_id = t1.age) col from test_sql t1 order by col;
"""
// EXISTS subquery projected as a boolean column.
qt_sql6 """
select exists ( select user_id from test_sql ) col from test_sql order by col;
"""
// CASE WHEN combining IN / EXISTS / scalar subqueries in one expression.
qt_sql7 """
select case when age in (select user_id from test_sql) or age in (select user_id from test_sql t2 where t2.user_id = t1.age) or exists ( select user_id from test_sql ) or exists ( select t2.user_id from test_sql t2 where t2.age = t1.user_id) or age < (select sum(age) from test_sql t2 where t2.dt = t1.dt ) then 2 else 1 end col from test_sql t1 order by col;
"""
// Second row: the uncorrelated scalar subquery of sql1 now yields two rows and
// must fail with a cardinality error.
sql """ insert into test_sql values (2,'2020-09-09',2,1);"""
try {
sql """
select (select age from test_sql) col from test_sql order by col;
"""
} catch (Exception ex) {
assertTrue(ex.getMessage().contains("Expected EQ 1 to be returned by expression"))
}
// Aggregate subqueries remain legal with multiple rows.
qt_sql8 """
select (select sum(age) from test_sql) col from test_sql order by col;
"""
qt_sql9 """
select (select sum(age) from test_sql t2 where t2.dt = t1.dt ) col from test_sql t1 order by col;
"""
// IN / EXISTS re-checked against the two-row table.
qt_sql10 """
select age in (select user_id from test_sql) col from test_sql order by col;
"""
qt_sql11 """
select age in (select user_id from test_sql t2 where t2.user_id = t1.age) col from test_sql t1 order by col;
"""
qt_sql12 """
select exists ( select user_id from test_sql ) col from test_sql order by col;
"""
qt_sql13 """
select case when age in (select user_id from test_sql) or age in (select user_id from test_sql t2 where t2.user_id = t1.age) or exists ( select user_id from test_sql ) or exists ( select t2.user_id from test_sql t2 where t2.age = t1.user_id) or age < (select sum(age) from test_sql t2 where t2.dt = t1.dt ) then 2 else 1 end col from test_sql t1 order by col;
"""
// Scalar subquery (a median computed via window functions) inside CASE WHEN,
// combined with GROUP BY CUBE on the outer query.
qt_sql14 """
select dt,case when 'med'='med' then (
select sum(midean) from (
select sum(score) / count(*) as midean
from (
select age score,row_number() over (order by age desc) as desc_math,
row_number() over (order by age asc) as asc_math from test_sql
) as order_table
where asc_math in (desc_math, desc_math + 1, desc_math - 1)) m
)
end 'test' from test_sql group by cube(dt) order by dt;
"""
// Clean up.
sql """drop table if exists test_sql;"""
}