From 27d6794b81250d460fbbe7403543f76c548e4893 Mon Sep 17 00:00:00 2001 From: EmmyMiao87 <522274284@qq.com> Date: Thu, 5 Dec 2019 21:27:33 +0800 Subject: [PATCH] Support subquery with non-scalar result in Binary predicate and Between-and predicate (#2360) This commit add a new plan node named AssertNumRowsNode which is used to determine whether the number of rows exceeds the limit. The subquery in Binary predicate and Between-and predicate should be added a AssertNumRowsNode which is used to determine whether the number of rows in subquery is more than 1. If the number of rows in subquery is more than 1, the query will be cancelled. For example: There are 4 rows in table t1. Query: select c1 from t1 where c1=(select c2 from t1); Result: ERROR 1064 (HY000): Expected no more than 1 to be returned by expression select c2 from t1 ISSUE-2270 TPC-DS 6,54,58 --- be/src/exec/CMakeLists.txt | 1 + be/src/exec/assert_num_rows_node.cpp | 75 +++++++++++++++ be/src/exec/assert_num_rows_node.h | 41 ++++++++ be/src/exec/exec_node.cpp | 5 + .../doris/analysis/AssertNumRowsElement.java | 36 +++++++ .../doris/analysis/BetweenPredicate.java | 6 ++ .../doris/analysis/BinaryPredicate.java | 18 ++-- .../doris/analysis/FunctionCallExpr.java | 3 + .../org/apache/doris/analysis/QueryStmt.java | 16 +++- .../org/apache/doris/analysis/SelectStmt.java | 2 +- .../apache/doris/analysis/StmtRewriter.java | 64 ++++++++++--- .../doris/load/loadv2/BrokerLoadJob.java | 2 +- .../doris/planner/AssertNumRowsNode.java | 62 ++++++++++++ .../doris/planner/DistributedPlanner.java | 22 +++++ .../apache/doris/planner/HashJoinNode.java | 1 - .../doris/planner/SingleNodePlanner.java | 12 +++ .../doris/analysis/BetweenPredicateTest.java | 54 +++++++++++ .../doris/analysis/BinaryPredicateTest.java | 71 ++++++++++++++ .../doris/planner/DistributedPlannerTest.java | 94 +++++++++++++++++++ gensrc/thrift/PlanNodes.thrift | 9 +- 20 files changed, 565 insertions(+), 29 deletions(-) create mode 100644 be/src/exec/assert_num_rows_node.cpp create mode 100644 be/src/exec/assert_num_rows_node.h create mode 100644 fe/src/main/java/org/apache/doris/analysis/AssertNumRowsElement.java create mode 100644 fe/src/main/java/org/apache/doris/planner/AssertNumRowsNode.java create mode 100644 fe/src/test/java/org/apache/doris/analysis/BetweenPredicateTest.java create mode 100644 fe/src/test/java/org/apache/doris/analysis/BinaryPredicateTest.java create mode 100644 fe/src/test/java/org/apache/doris/planner/DistributedPlannerTest.java diff --git a/be/src/exec/CMakeLists.txt b/be/src/exec/CMakeLists.txt index ac8927aa7f..5decdf4fcf 100644 --- a/be/src/exec/CMakeLists.txt +++ b/be/src/exec/CMakeLists.txt @@ -93,6 +93,7 @@ set(EXEC_FILES broker_writer.cpp parquet_scanner.cpp parquet_reader.cpp + assert_num_rows_node.cpp ) if (WITH_MYSQL) diff --git a/be/src/exec/assert_num_rows_node.cpp b/be/src/exec/assert_num_rows_node.cpp new file mode 100644 index 0000000000..26112f65d1 --- /dev/null +++ b/be/src/exec/assert_num_rows_node.cpp @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "exec/assert_num_rows_node.h" + +#include "runtime/row_batch.h" +#include "runtime/runtime_state.h" +#include "util/runtime_profile.h" +#include "gen_cpp/PlanNodes_types.h" +#include "gutil/strings/substitute.h" + +namespace doris { + +AssertNumRowsNode::AssertNumRowsNode( + ObjectPool* pool, const TPlanNode& tnode, const DescriptorTbl& descs) : + ExecNode(pool, tnode, descs), + _desired_num_rows(tnode.assert_num_rows_node.desired_num_rows), + _subquery_string(tnode.assert_num_rows_node.subquery_string) { +} + +Status AssertNumRowsNode::init(const TPlanNode& tnode, RuntimeState* state) { + RETURN_IF_ERROR(ExecNode::init(tnode, state)); + return Status::OK(); +} + +Status AssertNumRowsNode::prepare(RuntimeState* state) { + RETURN_IF_ERROR(ExecNode::prepare(state)); + return Status::OK(); +} + +Status AssertNumRowsNode::open(RuntimeState* state) { + SCOPED_TIMER(_runtime_profile->total_time_counter()); + RETURN_IF_ERROR(ExecNode::open(state)); + return Status::OK(); +} + +Status AssertNumRowsNode::get_next(RuntimeState* state, RowBatch* output_batch, bool* eos) { + RETURN_IF_ERROR(exec_debug_action(TExecNodePhase::GETNEXT)); + SCOPED_TIMER(_runtime_profile->total_time_counter()); + output_batch->reset(); + child(0)->get_next(state, output_batch, eos); + _num_rows_returned += output_batch->num_rows(); + if (_num_rows_returned > _desired_num_rows) { + LOG(INFO) << "Expected no more than " << _desired_num_rows << " to be returned by expression " + << _subquery_string; + return Status::Cancelled(strings::Substitute( + "Expected no more than $0 to be returned by expression $1", + _desired_num_rows, _subquery_string)); + } + COUNTER_SET(_rows_returned_counter, _num_rows_returned); + return Status::OK(); +} + +Status AssertNumRowsNode::close(RuntimeState* state) { + if (is_closed()) { + return Status::OK(); + } + return ExecNode::close(state); +} + +} \ No newline at end of file diff --git a/be/src/exec/assert_num_rows_node.h b/be/src/exec/assert_num_rows_node.h new file mode 100644 index 0000000000..0ed5ddcf32 --- /dev/null +++ b/be/src/exec/assert_num_rows_node.h @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "exec/exec_node.h" +#include "gen_cpp/PlanNodes_types.h" + +namespace doris { + +// Node for assert row count: +// - +class AssertNumRowsNode : public ExecNode { +public: + AssertNumRowsNode(ObjectPool* pool, const TPlanNode& tnode, const DescriptorTbl& descs); + virtual ~AssertNumRowsNode() {}; + + virtual Status init(const TPlanNode& tnode, RuntimeState* state = nullptr); + virtual Status prepare(RuntimeState* state); + virtual Status open(RuntimeState* state); + virtual Status get_next(RuntimeState* state, RowBatch* row_batch, bool* eos); + virtual Status close(RuntimeState* state); + +private: + int64_t _desired_num_rows; + const std::string _subquery_string; +}; + +} \ No newline at end of file diff --git a/be/src/exec/exec_node.cpp b/be/src/exec/exec_node.cpp index ab0059645d..634a073c9c 100644 --- a/be/src/exec/exec_node.cpp +++ b/be/src/exec/exec_node.cpp @@ -50,6 +50,7 @@ #include "exec/analytic_eval_node.h" #include "exec/select_node.h" #include "exec/union_node.h" +#include "exec/assert_num_rows_node.h" #include "runtime/exec_env.h" #include "runtime/descriptors.h" #include "runtime/initial_reservations.h" @@ -451,6 +452,10 @@ Status ExecNode::create_node(RuntimeState* state, ObjectPool* pool, const TPlanN *node = pool->add(new BrokerScanNode(pool, tnode, descs)); return Status::OK(); + case TPlanNodeType::ASSERT_NUM_ROWS_NODE: + *node = pool->add(new AssertNumRowsNode(pool, tnode, descs)); + return Status::OK(); + default: map::const_iterator i = _TPlanNodeType_VALUES_TO_NAMES.find(tnode.node_type); diff --git a/fe/src/main/java/org/apache/doris/analysis/AssertNumRowsElement.java b/fe/src/main/java/org/apache/doris/analysis/AssertNumRowsElement.java new file mode 100644 index 0000000000..3bc625e157 --- /dev/null +++ b/fe/src/main/java/org/apache/doris/analysis/AssertNumRowsElement.java @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.analysis; + +public class AssertNumRowsElement { + private long desiredNumOfRows; + private String subqueryString; + + public AssertNumRowsElement(long desiredNumOfRows, String subqueryString) { + this.desiredNumOfRows = desiredNumOfRows; + this.subqueryString = subqueryString; + } + + public long getDesiredNumOfRows() { + return desiredNumOfRows; + } + + public String getSubqueryString() { + return subqueryString; + } +} diff --git a/fe/src/main/java/org/apache/doris/analysis/BetweenPredicate.java b/fe/src/main/java/org/apache/doris/analysis/BetweenPredicate.java index 939ad16762..5fad60ace6 100644 --- a/fe/src/main/java/org/apache/doris/analysis/BetweenPredicate.java +++ b/fe/src/main/java/org/apache/doris/analysis/BetweenPredicate.java @@ -70,6 +70,12 @@ public class BetweenPredicate extends Predicate { throw new AnalysisException("Comparison between subqueries is not " + "supported in a BETWEEN predicate: " + toSql()); } + // if children has subquery, it will be written and reanalyzed in the future. + if (children.get(0) instanceof Subquery + || children.get(1) instanceof Subquery + || children.get(2) instanceof Subquery) { + return; + } analyzer.castAllToCompatibleType(children); } diff --git a/fe/src/main/java/org/apache/doris/analysis/BinaryPredicate.java b/fe/src/main/java/org/apache/doris/analysis/BinaryPredicate.java index 79ea427a22..34897aadbe 100644 --- a/fe/src/main/java/org/apache/doris/analysis/BinaryPredicate.java +++ b/fe/src/main/java/org/apache/doris/analysis/BinaryPredicate.java @@ -137,7 +137,7 @@ public class BinaryPredicate extends Predicate implements Writable { private Operator op; // check if left is slot and right isnot slot. private Boolean slotIsleft = null; - + // for restoring public BinaryPredicate() { super(); @@ -243,7 +243,6 @@ public class BinaryPredicate extends Predicate implements Writable { @Override public void vectorizedAnalyze(Analyzer analyzer) { super.vectorizedAnalyze(analyzer); - Type cmpType = getCmpType(); Function match = null; //OpcodeRegistry.BuiltinFunction match = OpcodeRegistry.instance().getFunctionInfo( @@ -305,7 +304,7 @@ public class BinaryPredicate extends Predicate implements Writable { && (t2 == PrimitiveType.BIGINT || t2 == PrimitiveType.LARGEINT)) { return Type.LARGEINT; } - + return Type.DOUBLE; } @@ -314,13 +313,18 @@ public class BinaryPredicate extends Predicate implements Writable { super.analyzeImpl(analyzer); for (Expr expr : children) { - if (expr instanceof Subquery && !expr.getType().isScalarType()) { - throw new AnalysisException("BinaryPredicate can't contain subquery or non scalar type"); - } + if (expr instanceof Subquery && !((Subquery) expr).returnsScalarColumn()) { + String msg = "Subquery of binary predicate must return a single column: " + expr.toSql(); + throw new AnalysisException(msg); + } + } + + // if children has subquery, it will be written and reanalyzed in the future. + if (children.get(0) instanceof Subquery || children.get(1) instanceof Subquery) { + return; } Type cmpType = getCmpType(); - // Ignore return value because type is always bool for predicates. castBinaryOp(cmpType); diff --git a/fe/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index 3720883e8a..4455dfca7a 100644 --- a/fe/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -545,6 +545,9 @@ public class FunctionCallExpr extends Expr { analyzeBuiltinAggFunction(analyzer); if (fnName.getFunction().equalsIgnoreCase("sum")) { + if (this.children.isEmpty()) { + throw new AnalysisException("The " + fnName + " function must has one input param"); + } Type type = getChild(0).type.getMaxResolutionType(); fn = getBuiltinFunction(analyzer, fnName.getFunction(), new Type[]{type}, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); diff --git a/fe/src/main/java/org/apache/doris/analysis/QueryStmt.java b/fe/src/main/java/org/apache/doris/analysis/QueryStmt.java index 00567e8930..8f59ce1e21 100644 --- a/fe/src/main/java/org/apache/doris/analysis/QueryStmt.java +++ b/fe/src/main/java/org/apache/doris/analysis/QueryStmt.java @@ -54,7 +54,11 @@ public abstract class QueryStmt extends StatementBase { protected WithClause withClause_; protected ArrayList orderByElements; + // Limit element could not be null, the default limit element is NO_LIMIT protected LimitElement limitElement; + // This is a internal element which is used to query plan. + // It will not change the origin stmt and present in toSql. + protected AssertNumRowsElement assertNumRowsElement; /** * For a select statment: @@ -185,10 +189,6 @@ public abstract class QueryStmt extends StatementBase { return evaluateOrderBy; } - public void setEvaluateOrderBy(boolean evaluateOrderBy) { - this.evaluateOrderBy = evaluateOrderBy; - } - public ArrayList getBaseTblResultExprs() { return baseTblResultExprs; } @@ -419,6 +419,14 @@ public abstract class QueryStmt extends StatementBase { return limitElement.getOffset(); } + public void setAssertNumRowsElement(int desiredNumOfRows) { + this.assertNumRowsElement = new AssertNumRowsElement(desiredNumOfRows, toSql()); + } + + public AssertNumRowsElement getAssertNumRowsElement() { + return assertNumRowsElement; + } + public void setIsExplain(boolean isExplain) { this.isExplain = isExplain; } diff --git a/fe/src/main/java/org/apache/doris/analysis/SelectStmt.java b/fe/src/main/java/org/apache/doris/analysis/SelectStmt.java index cf5f4eb5d0..35ac54c04c 100644 --- a/fe/src/main/java/org/apache/doris/analysis/SelectStmt.java +++ b/fe/src/main/java/org/apache/doris/analysis/SelectStmt.java @@ -30,10 +30,10 @@ import org.apache.doris.common.AnalysisException; import org.apache.doris.common.ColumnAliasGenerator; import org.apache.doris.common.ErrorCode; import org.apache.doris.common.ErrorReport; -import org.apache.doris.common.UserException; import org.apache.doris.common.Pair; import org.apache.doris.common.TableAliasGenerator; import org.apache.doris.common.TreeNode; +import org.apache.doris.common.UserException; import org.apache.doris.common.util.SqlUtils; import org.apache.doris.mysql.privilege.PrivPredicate; import org.apache.doris.qe.ConnectContext; diff --git a/fe/src/main/java/org/apache/doris/analysis/StmtRewriter.java b/fe/src/main/java/org/apache/doris/analysis/StmtRewriter.java index c92d95a088..ddef1d18c9 100644 --- a/fe/src/main/java/org/apache/doris/analysis/StmtRewriter.java +++ b/fe/src/main/java/org/apache/doris/analysis/StmtRewriter.java @@ -17,20 +17,21 @@ package org.apache.doris.analysis; -import java.util.ArrayList; -import java.util.List; - -import com.google.common.collect.Iterables; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import org.apache.doris.catalog.Type; import org.apache.doris.common.AnalysisException; import org.apache.doris.common.UserException; + import com.google.common.base.Preconditions; import com.google.common.base.Predicates; +import com.google.common.collect.Iterables; import com.google.common.collect.Lists; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; + /** * Class representing a statement rewriter. A statement rewriter performs subquery * unnesting on an analyzed parse tree. @@ -249,6 +250,21 @@ public class StmtRewriter { } } + + /** + * Situation: The expr is a binary predicate and the type of subquery is not scalar type. + * Rewrite: The stmt of inline view is added an assert condition (return error if row count > 1). + * Input params: + * expr: k1=(select k1 from t2) + * origin inline view: (select k1 $a from t2) $b + * stmt: select * from t1 where k1=(select k1 from t2); + * Output params: + * rewritten inline view: (select k1 $a from t2 (assert row count: return error if row count > 1)) $b + */ + private static void rewriteBinaryPredicateWithSubquery(InlineViewRef inlineViewRef) { + inlineViewRef.getViewStmt().setAssertNumRowsElement(1); + } + /** * Replace an ExistsPredicate that contains a subquery with a BoolLiteral if we * can determine its result without evaluating it. Return null if the result of the @@ -315,8 +331,17 @@ public class StmtRewriter { } } + /** + * origin stmt: select * from t1 where t1=(select k1 from t2); + * + * @param stmt select * from t1 where true; + * @param expr t1=(select k1 from t2). The expr has already be rewritten. + * @param analyzer + * @return + * @throws AnalysisException + */ private static boolean mergeExpr(SelectStmt stmt, Expr expr, - Analyzer analyzer) throws AnalysisException { + Analyzer analyzer) throws AnalysisException { // LOG.warn("dhc mergeExpr stmt={} expr={}", stmt, expr); LOG.debug("SUBQUERY mergeExpr stmt={} expr={}", stmt.toSql(), expr.toSql()); Preconditions.checkNotNull(expr); @@ -331,9 +356,11 @@ public class StmtRewriter { // to eliminate any chance that column aliases from the parent query could reference // select items from the inline view after the rewrite. List colLabels = Lists.newArrayList(); + // add a new alias for all of columns in subquery for (int i = 0; i < subqueryStmt.getColLabels().size(); ++i) { colLabels.add(subqueryStmt.getColumnAliasGenerator().getNextAlias()); } + // (select k1 $a from t2) $b InlineViewRef inlineView = new InlineViewRef( stmt.getTableAliasGenerator().getNextAlias(), subqueryStmt, colLabels); @@ -362,12 +389,24 @@ public class StmtRewriter { lhsExprs, rhsExprs, updateGroupBy); } + /** + * Situation: The expr is a uncorrelated subquery for outer stmt. + * Rewrite: Add a limit 1 for subquery. + * origin stmt: select * from t1 where exists (select * from table2); + * expr: exists (select * from table2) + * outer stmt: select * from t1 + * onClauseConjuncts: empty. + */ if (expr instanceof ExistsPredicate && onClauseConjuncts.isEmpty()) { // For uncorrelated subqueries, we limit the number of rows returned by the // subquery. subqueryStmt.setLimit(1); } + if (expr instanceof BinaryPredicate && !expr.getSubquery().getType().isScalarType()) { + rewriteBinaryPredicateWithSubquery(inlineView); + } + // Analyzing the inline view trigger reanalysis of the subquery's select statement. // However the statement is already analyzed and since statement analysis is not // idempotent, the analysis needs to be reset (by a call to clone()). @@ -383,7 +422,7 @@ public class StmtRewriter { stmt.fromClause_.add(inlineView); JoinOperator joinOp = JoinOperator.LEFT_SEMI_JOIN; - // Create a join conjunct from the expr that contains a subquery. + // Create a join conjunct from the expr that contains a subquery. Expr joinConjunct = createJoinConjunct(expr, inlineView, analyzer, !onClauseConjuncts.isEmpty()); if (joinConjunct != null) { @@ -481,7 +520,7 @@ public class StmtRewriter { // TODO: Requires support for non-equi joins. boolean hasGroupBy = ((SelectStmt) inlineView.getViewStmt()).hasGroupByClause(); // boolean hasGroupBy = false; - if (!expr.getSubquery().isScalarSubquery() + if (!expr.getSubquery().returnsScalarColumn() || (!(hasGroupBy && stmt.selectList.isDistinct()) && hasGroupBy)) { throw new AnalysisException("Unsupported predicate with subquery: " + expr.toSql()); @@ -863,11 +902,8 @@ public class StmtRewriter { pred.analyze(analyzer); return pred; } - // Only scalar subqueries are supported + Subquery subquery = exprWithSubquery.getSubquery(); - if (!subquery.isScalarSubquery()) { - throw new AnalysisException("Unsupported predicate with a non-scalar subquery: " + subquery.toSql()); - } ExprSubstitutionMap smap = new ExprSubstitutionMap(); SelectListItem item = ((SelectStmt) inlineView.getViewStmt()).getSelectList().getItems().get(0); diff --git a/fe/src/main/java/org/apache/doris/load/loadv2/BrokerLoadJob.java b/fe/src/main/java/org/apache/doris/load/loadv2/BrokerLoadJob.java index 6945a4450f..4614727674 100644 --- a/fe/src/main/java/org/apache/doris/load/loadv2/BrokerLoadJob.java +++ b/fe/src/main/java/org/apache/doris/load/loadv2/BrokerLoadJob.java @@ -18,7 +18,6 @@ package org.apache.doris.load.loadv2; -import com.google.common.collect.Maps; import org.apache.doris.analysis.BrokerDesc; import org.apache.doris.analysis.DataDescription; import org.apache.doris.analysis.LoadStmt; @@ -56,6 +55,7 @@ import org.apache.doris.transaction.TransactionState; import com.google.common.base.Joiner; import com.google.common.base.Strings; import com.google.common.collect.Lists; +import com.google.common.collect.Maps; import com.google.common.collect.Sets; import org.apache.logging.log4j.LogManager; diff --git a/fe/src/main/java/org/apache/doris/planner/AssertNumRowsNode.java b/fe/src/main/java/org/apache/doris/planner/AssertNumRowsNode.java new file mode 100644 index 0000000000..033935c0aa --- /dev/null +++ b/fe/src/main/java/org/apache/doris/planner/AssertNumRowsNode.java @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.planner; + +import org.apache.doris.analysis.AssertNumRowsElement; +import org.apache.doris.thrift.TAssertNumRowsNode; +import org.apache.doris.thrift.TExplainLevel; +import org.apache.doris.thrift.TPlanNode; +import org.apache.doris.thrift.TPlanNodeType; + +/** + * Assert num rows node is used to determine whether the number of rows is less then desired num of rows. + * The rows are the result of subqueryString. + * If the number of rows is more then the desired num of rows, the query will be cancelled. + * The cancelled reason will be reported by Backend and displayed back to the user. + */ +public class AssertNumRowsNode extends PlanNode { + + private long desiredNumOfRows; + private String subqueryString; + + public AssertNumRowsNode(PlanNodeId id, PlanNode input, AssertNumRowsElement assertNumRowsElement) { + super(id, "ASSERT NUMBER OF ROWS"); + this.desiredNumOfRows = assertNumRowsElement.getDesiredNumOfRows(); + this.subqueryString = assertNumRowsElement.getSubqueryString(); + this.children.add(input); + this.tupleIds = input.getTupleIds(); + this.tblRefIds = input.getTblRefIds(); + this.nullableTupleIds = input.getNullableTupleIds(); + } + + @Override + protected String getNodeExplainString(String prefix, TExplainLevel detailLevel) { + StringBuilder output = new StringBuilder() + .append(prefix + "assert number of rows: " ) + .append(desiredNumOfRows + "\n"); + return output.toString(); + } + + @Override + protected void toThrift(TPlanNode msg) { + msg.node_type = TPlanNodeType.ASSERT_NUM_ROWS_NODE; + msg.assert_num_rows_node = new TAssertNumRowsNode(); + msg.assert_num_rows_node.setDesired_num_rows(desiredNumOfRows); + msg.assert_num_rows_node.setSubquery_string(subqueryString); + } +} diff --git a/fe/src/main/java/org/apache/doris/planner/DistributedPlanner.java b/fe/src/main/java/org/apache/doris/planner/DistributedPlanner.java index 280eb28b43..ec62a785ee 100644 --- a/fe/src/main/java/org/apache/doris/planner/DistributedPlanner.java +++ b/fe/src/main/java/org/apache/doris/planner/DistributedPlanner.java @@ -223,6 +223,8 @@ public class DistributedPlanner { result = createAnalyticFragment(root, childFragments.get(0), fragments); } else if (root instanceof EmptySetNode) { result = new PlanFragment(ctx_.getNextFragmentId(), root, DataPartition.UNPARTITIONED); + } else if (root instanceof AssertNumRowsNode) { + result = createAssertFragment(root, childFragments.get(0)); } else { throw new UserException( "Cannot create plan fragment for this node type: " + root.getExplainString()); @@ -1051,4 +1053,24 @@ public class DistributedPlanner { return mergeFragment; } + private PlanFragment createAssertFragment(PlanNode assertRowCountNode, PlanFragment inputFragment) + throws UserException { + Preconditions.checkState(assertRowCountNode instanceof AssertNumRowsNode); + if (!inputFragment.isPartitioned()) { + inputFragment.addPlanRoot(assertRowCountNode); + return inputFragment; + } + + // Create a new fragment for assert row count node + PlanFragment mergeFragment = createParentFragment(inputFragment, DataPartition.UNPARTITIONED); + ExchangeNode exchNode = (ExchangeNode) mergeFragment.getPlanRoot(); + mergeFragment.addPlanRoot(assertRowCountNode); + + // reset the stat of assert row count node + exchNode.computeStats(ctx_.getRootAnalyzer()); + assertRowCountNode.computeStats(ctx_.getRootAnalyzer()); + + return mergeFragment; + } + } diff --git a/fe/src/main/java/org/apache/doris/planner/HashJoinNode.java b/fe/src/main/java/org/apache/doris/planner/HashJoinNode.java index 9595259dba..269173b736 100644 --- a/fe/src/main/java/org/apache/doris/planner/HashJoinNode.java +++ b/fe/src/main/java/org/apache/doris/planner/HashJoinNode.java @@ -272,7 +272,6 @@ public class HashJoinNode extends PlanNode { } @Override - protected String getNodeExplainString(String detailPrefix, TExplainLevel detailLevel) { String distrModeStr = (distrMode != DistributionMode.NONE) ? (" (" + distrMode.toString() + ")") : ""; diff --git a/fe/src/main/java/org/apache/doris/planner/SingleNodePlanner.java b/fe/src/main/java/org/apache/doris/planner/SingleNodePlanner.java index 30d7ba0250..8a4a7b4a36 100644 --- a/fe/src/main/java/org/apache/doris/planner/SingleNodePlanner.java +++ b/fe/src/main/java/org/apache/doris/planner/SingleNodePlanner.java @@ -20,6 +20,7 @@ package org.apache.doris.planner; import org.apache.doris.analysis.AggregateInfo; import org.apache.doris.analysis.AnalyticInfo; import org.apache.doris.analysis.Analyzer; +import org.apache.doris.analysis.AssertNumRowsElement; import org.apache.doris.analysis.BaseTableRef; import org.apache.doris.analysis.BinaryPredicate; import org.apache.doris.analysis.CaseExpr; @@ -269,6 +270,10 @@ public class SingleNodePlanner { root.computeStats(analyzer); } + // adding assert node at the end of single node planner + if (stmt.getAssertNumRowsElement() != null) { + root = createAssertRowCountNode(root, stmt.getAssertNumRowsElement(), analyzer); + } return root; } @@ -1560,6 +1565,13 @@ public class SingleNodePlanner { return result; } + private PlanNode createAssertRowCountNode(PlanNode input, AssertNumRowsElement assertNumRowsElement, + Analyzer analyzer) throws UserException { + AssertNumRowsNode root = new AssertNumRowsNode(ctx_.getNextNodeId(), input, assertNumRowsElement); + root.init(analyzer); + return root; + } + /** * According to the way to materialize slots from top to bottom, Materialization will prune columns * which are not referenced by Statement outside. However, in some cases, in order to ensure The diff --git a/fe/src/test/java/org/apache/doris/analysis/BetweenPredicateTest.java b/fe/src/test/java/org/apache/doris/analysis/BetweenPredicateTest.java new file mode 100644 index 0000000000..4ee5c213a6 --- /dev/null +++ b/fe/src/test/java/org/apache/doris/analysis/BetweenPredicateTest.java @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.analysis; + +import org.apache.doris.common.AnalysisException; + +import org.junit.Assert; +import org.junit.Test; + +import mockit.Injectable; +import mockit.Mocked; + +public class BetweenPredicateTest { + @Mocked Analyzer analyzer; + + @Test + public void testWithCompareAndBoundSubquery(@Injectable Subquery compareExpr, + @Injectable Subquery lowerBound, + @Injectable Expr upperBound) { + BetweenPredicate betweenPredicate = new BetweenPredicate(compareExpr, lowerBound, upperBound, false); + try { + betweenPredicate.analyzeImpl(analyzer); + Assert.fail(); + } catch (AnalysisException e) { + } + } + + @Test + public void testWithBoundSubquery(@Injectable Expr compareExpr, + @Injectable Subquery lowerBound, + @Injectable Subquery upperBound) { + BetweenPredicate betweenPredicate = new BetweenPredicate(compareExpr, lowerBound, upperBound, false); + try { + betweenPredicate.analyzeImpl(analyzer); + } catch (AnalysisException e) { + Assert.fail(e.getMessage()); + } + } +} diff --git a/fe/src/test/java/org/apache/doris/analysis/BinaryPredicateTest.java b/fe/src/test/java/org/apache/doris/analysis/BinaryPredicateTest.java new file mode 100644 index 0000000000..757f36ee57 --- /dev/null +++ b/fe/src/test/java/org/apache/doris/analysis/BinaryPredicateTest.java @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.analysis; + +import org.apache.doris.common.AnalysisException; + +import org.junit.Assert; +import org.junit.Test; + +import mockit.Deencapsulation; +import mockit.Expectations; +import mockit.Injectable; +import mockit.Mocked; + +public class BinaryPredicateTest { + + @Mocked + Analyzer analyzer; + + @Test + public void testMultiColumnSubquery(@Injectable Expr child0, + @Injectable Subquery child1) { + BinaryPredicate binaryPredicate = new BinaryPredicate(BinaryPredicate.Operator.EQ, child0, child1); + new Expectations() { + { + child1.returnsScalarColumn(); + result = false; + } + }; + + try { + binaryPredicate.analyzeImpl(analyzer); + Assert.fail(); + } catch (AnalysisException e) { + } + } + + @Test + public void testSingleColumnSubquery(@Injectable Expr child0, + @Injectable Subquery child1) { + BinaryPredicate binaryPredicate = new BinaryPredicate(BinaryPredicate.Operator.EQ, child0, child1); + new Expectations() { + { + child1.returnsScalarColumn(); + result = true; + } + }; + + try { + binaryPredicate.analyzeImpl(analyzer); + Assert.assertSame(null, Deencapsulation.getField(binaryPredicate, "fn")); + } catch (AnalysisException e) { + Assert.fail(e.getMessage()); + } + } +} diff --git a/fe/src/test/java/org/apache/doris/planner/DistributedPlannerTest.java b/fe/src/test/java/org/apache/doris/planner/DistributedPlannerTest.java new file mode 100644 index 0000000000..2a6bc838ec --- /dev/null +++ b/fe/src/test/java/org/apache/doris/planner/DistributedPlannerTest.java @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.planner; + +import org.apache.doris.analysis.TupleId; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.List; +import java.util.Set; + +import mockit.Deencapsulation; +import mockit.Expectations; +import mockit.Injectable; +import mockit.Mocked; + +public class DistributedPlannerTest { + + @Mocked + PlannerContext plannerContext; + + @Test + public void testAssertFragmentWithDistributedInput(@Injectable AssertNumRowsNode assertNumRowsNode, + @Injectable PlanFragment inputFragment, + @Injectable PlanNodeId planNodeId, + @Injectable PlanFragmentId planFragmentId, + @Injectable PlanNode inputPlanRoot, + @Injectable TupleId tupleId) { + DistributedPlanner distributedPlanner = new DistributedPlanner(plannerContext); + + List tupleIdList = Lists.newArrayList(tupleId); + Set tupleIdSet = Sets.newHashSet(tupleId); + Deencapsulation.setField(inputPlanRoot, "tupleIds", tupleIdList); + Deencapsulation.setField(inputPlanRoot, "tblRefIds", tupleIdList); + Deencapsulation.setField(inputPlanRoot, "nullableTupleIds", Sets.newHashSet(tupleId)); + Deencapsulation.setField(inputPlanRoot, "conjuncts", Lists.newArrayList()); + new Expectations() { + { + inputFragment.isPartitioned(); + result = true; + plannerContext.getNextNodeId(); + result = planNodeId; + plannerContext.getNextFragmentId(); + result = planFragmentId; + inputFragment.getPlanRoot(); + result = inputPlanRoot; + inputPlanRoot.getTupleIds(); + result = tupleIdList; + inputPlanRoot.getTblRefIds(); + result = tupleIdList; + inputPlanRoot.getNullableTupleIds(); + result = tupleIdSet; + assertNumRowsNode.getChildren(); + result = inputPlanRoot; + } + }; + + PlanFragment assertFragment = Deencapsulation.invoke(distributedPlanner, "createAssertFragment", + assertNumRowsNode, inputFragment); + Assert.assertFalse(assertFragment.isPartitioned()); + Assert.assertSame(assertNumRowsNode, assertFragment.getPlanRoot()); + } + + @Test + public void testAssertFragmentWithUnpartitionInput(@Injectable AssertNumRowsNode assertNumRowsNode, + @Injectable PlanFragment inputFragment){ + DistributedPlanner distributedPlanner = new DistributedPlanner(plannerContext); + + PlanFragment assertFragment = Deencapsulation.invoke(distributedPlanner, "createAssertFragment", + assertNumRowsNode, inputFragment); + Assert.assertSame(assertFragment, inputFragment); + Assert.assertTrue(assertFragment.getPlanRoot() instanceof AssertNumRowsNode); + } + +} diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift index b13c70d3d2..6a5b9aa3d2 100644 --- a/gensrc/thrift/PlanNodes.thrift +++ b/gensrc/thrift/PlanNodes.thrift @@ -45,7 +45,8 @@ enum TPlanNodeType { EMPTY_SET_NODE, UNION_NODE, ES_SCAN_NODE, - ES_HTTP_SCAN_NODE + ES_HTTP_SCAN_NODE, + ASSERT_NUM_ROWS_NODE } // phases of an execution node @@ -568,6 +569,11 @@ struct TBackendResourceProfile { 4: optional i64 max_row_buffer_size = 4194304 //TODO chenhao } +struct TAssertNumRowsNode { + 1: optional i64 desired_num_rows; + 2: optional string subquery_string; +} + // This is essentially a union of all messages corresponding to subclasses // of PlanNode. struct TPlanNode { @@ -605,6 +611,7 @@ struct TPlanNode { 28: optional TUnionNode union_node 29: optional TBackendResourceProfile resource_profile 30: optional TEsScanNode es_scan_node + 31: optional TAssertNumRowsNode assert_num_rows_node } // A flattened representation of a tree of PlanNodes, obtained by depth-first