From f90da720787a4a8c79c353d21c2d73d3180ff953 Mon Sep 17 00:00:00 2001 From: yangzhg <780531911@qq.com> Date: Fri, 8 May 2020 12:49:48 +0800 Subject: [PATCH] [Planner]Enhance AssertNumRowsNode (#3485) Enhance AssertNumRowsNode to support equal, less than, greater than,... assert conditions --- be/src/exec/assert_num_rows_node.cpp | 65 +++++++++++++++---- be/src/exec/assert_num_rows_node.h | 8 +-- .../apache/doris/analysis/ArithmeticExpr.java | 2 +- .../doris/analysis/AssertNumRowsElement.java | 38 ++++++++++- .../doris/analysis/BinaryPredicate.java | 2 +- .../org/apache/doris/analysis/QueryStmt.java | 4 +- .../doris/planner/AssertNumRowsNode.java | 7 +- gensrc/thrift/PlanNodes.thrift | 10 +++ 8 files changed, 112 insertions(+), 24 deletions(-) diff --git a/be/src/exec/assert_num_rows_node.cpp b/be/src/exec/assert_num_rows_node.cpp index 407aeb0d95..81aeb1e81e 100644 --- a/be/src/exec/assert_num_rows_node.cpp +++ b/be/src/exec/assert_num_rows_node.cpp @@ -17,19 +17,24 @@ #include "exec/assert_num_rows_node.h" +#include "gen_cpp/PlanNodes_types.h" +#include "gutil/strings/substitute.h" #include "runtime/row_batch.h" #include "runtime/runtime_state.h" #include "util/runtime_profile.h" -#include "gen_cpp/PlanNodes_types.h" -#include "gutil/strings/substitute.h" namespace doris { -AssertNumRowsNode::AssertNumRowsNode( - ObjectPool* pool, const TPlanNode& tnode, const DescriptorTbl& descs) : - ExecNode(pool, tnode, descs), - _desired_num_rows(tnode.assert_num_rows_node.desired_num_rows), - _subquery_string(tnode.assert_num_rows_node.subquery_string) { +AssertNumRowsNode::AssertNumRowsNode(ObjectPool* pool, const TPlanNode& tnode, + const DescriptorTbl& descs) + : ExecNode(pool, tnode, descs), + _desired_num_rows(tnode.assert_num_rows_node.desired_num_rows), + _subquery_string(tnode.assert_num_rows_node.subquery_string) { + if (tnode.assert_num_rows_node.__isset.assertion) { + _assertion = tnode.assert_num_rows_node.assertion; + } else { + _assertion = TAssertion::LE; // 
just compatible with the previous code + } } Status AssertNumRowsNode::init(const TPlanNode& tnode, RuntimeState* state) { @@ -56,12 +61,46 @@ Status AssertNumRowsNode::get_next(RuntimeState* state, RowBatch* output_batch, output_batch->reset(); child(0)->get_next(state, output_batch, eos); _num_rows_returned += output_batch->num_rows(); - if (_num_rows_returned > _desired_num_rows) { - LOG(INFO) << "Expected no more than " << _desired_num_rows << " to be returned by expression " - << _subquery_string; + bool assert_res = false; + switch (_assertion) { + case TAssertion::EQ: + assert_res = _num_rows_returned == _desired_num_rows; + break; + case TAssertion::NE: + assert_res = _num_rows_returned != _desired_num_rows; + break; + case TAssertion::LT: + assert_res = _num_rows_returned < _desired_num_rows; + break; + case TAssertion::LE: + assert_res = _num_rows_returned <= _desired_num_rows; + break; + case TAssertion::GT: + assert_res = _num_rows_returned > _desired_num_rows; + break; + case TAssertion::GE: + assert_res = _num_rows_returned >= _desired_num_rows; + break; + default: + break; + } + + if (!assert_res) { + auto to_string_lamba = [](TAssertion::type assertion) { + std::map::const_iterator it = + _TAssertion_VALUES_TO_NAMES.find(assertion); + + if (it == _TAggregationOp_VALUES_TO_NAMES.end()) { + return "NULL"; + } else { + return it->second; + } + }; + LOG(INFO) << "Expected " << to_string_lamba(_assertion) << " " << _desired_num_rows + << " to be returned by expression " << _subquery_string; return Status::Cancelled(strings::Substitute( - "Expected no more than $0 to be returned by expression $1", - _desired_num_rows, _subquery_string)); + "Expected $0 $1 to be returned by expression $2", to_string_lamba(_assertion), + _desired_num_rows, _subquery_string)); } COUNTER_SET(_rows_returned_counter, _num_rows_returned); return Status::OK(); @@ -74,4 +113,4 @@ Status AssertNumRowsNode::close(RuntimeState* state) { return ExecNode::close(state); } -} \ No newline 
at end of file +} // namespace doris diff --git a/be/src/exec/assert_num_rows_node.h b/be/src/exec/assert_num_rows_node.h index 0ed5ddcf32..8141147403 100644 --- a/be/src/exec/assert_num_rows_node.h +++ b/be/src/exec/assert_num_rows_node.h @@ -20,12 +20,11 @@ namespace doris { -// Node for assert row count: -// - +// Node for assert row count class AssertNumRowsNode : public ExecNode { public: AssertNumRowsNode(ObjectPool* pool, const TPlanNode& tnode, const DescriptorTbl& descs); - virtual ~AssertNumRowsNode() {}; + virtual ~AssertNumRowsNode(){}; virtual Status init(const TPlanNode& tnode, RuntimeState* state = nullptr); virtual Status prepare(RuntimeState* state); @@ -36,6 +35,7 @@ public: private: int64_t _desired_num_rows; const std::string _subquery_string; + TAssertion::type _assertion; }; -} \ No newline at end of file +} // namespace doris diff --git a/fe/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java b/fe/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java index 078245d492..beaa3cb059 100644 --- a/fe/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java +++ b/fe/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java @@ -300,7 +300,7 @@ public class ArithmeticExpr extends Expr { * subquery stmt: select k1 from t2 (assert row count: return error if row count > 1 ) */ if (!subquery.getType().isScalarType()) { - subquery.getStatement().setAssertNumRowsElement(1); + subquery.getStatement().setAssertNumRowsElement(1, AssertNumRowsElement.Assertion.LE); } } } diff --git a/fe/src/main/java/org/apache/doris/analysis/AssertNumRowsElement.java b/fe/src/main/java/org/apache/doris/analysis/AssertNumRowsElement.java index 3bc625e157..c37089c30f 100644 --- a/fe/src/main/java/org/apache/doris/analysis/AssertNumRowsElement.java +++ b/fe/src/main/java/org/apache/doris/analysis/AssertNumRowsElement.java @@ -17,13 +17,45 @@ package org.apache.doris.analysis; +import org.apache.doris.thrift.TAssertion; + public class AssertNumRowsElement { + 
public enum Assertion { + EQ, // val1 == val2 + NE, // val1 != val2 + LT, // val1 < val2 + LE, // val1 <= val2 + GT, // val1 > val2 + GE; // val1 >= val2 + + public TAssertion toThrift() { + switch (this) { + case EQ: + return TAssertion.EQ; + case NE: + return TAssertion.NE; + case LT: + return TAssertion.LT; + case LE: + return TAssertion.LE; + case GT: + return TAssertion.GT; + case GE: + return TAssertion.GE; + default: + return null; + } + } + } + private long desiredNumOfRows; private String subqueryString; + private Assertion assertion; - public AssertNumRowsElement(long desiredNumOfRows, String subqueryString) { + public AssertNumRowsElement(long desiredNumOfRows, String subqueryString, Assertion assertion) { this.desiredNumOfRows = desiredNumOfRows; this.subqueryString = subqueryString; + this.assertion = assertion; } public long getDesiredNumOfRows() { @@ -33,4 +65,8 @@ public class AssertNumRowsElement { public String getSubqueryString() { return subqueryString; } + + public Assertion getAssertion() { + return assertion; + } } diff --git a/fe/src/main/java/org/apache/doris/analysis/BinaryPredicate.java b/fe/src/main/java/org/apache/doris/analysis/BinaryPredicate.java index 0c7d6ffac0..d7220301c1 100644 --- a/fe/src/main/java/org/apache/doris/analysis/BinaryPredicate.java +++ b/fe/src/main/java/org/apache/doris/analysis/BinaryPredicate.java @@ -331,7 +331,7 @@ public class BinaryPredicate extends Predicate implements Writable { * subquery stmt: select k1 from t2 (assert row count: return error if row count > 1 ) */ if (!subquery.getType().isScalarType()) { - subquery.getStatement().setAssertNumRowsElement(1); + subquery.getStatement().setAssertNumRowsElement(1, AssertNumRowsElement.Assertion.LE); } } } diff --git a/fe/src/main/java/org/apache/doris/analysis/QueryStmt.java b/fe/src/main/java/org/apache/doris/analysis/QueryStmt.java index 467fd6b8dd..e6f0bfb0a7 100644 --- a/fe/src/main/java/org/apache/doris/analysis/QueryStmt.java +++ 
b/fe/src/main/java/org/apache/doris/analysis/QueryStmt.java @@ -478,8 +478,8 @@ public abstract class QueryStmt extends StatementBase { return limitElement.getOffset(); } - public void setAssertNumRowsElement(int desiredNumOfRows) { - this.assertNumRowsElement = new AssertNumRowsElement(desiredNumOfRows, toSql()); + public void setAssertNumRowsElement(int desiredNumOfRows, AssertNumRowsElement.Assertion assertion) { + this.assertNumRowsElement = new AssertNumRowsElement(desiredNumOfRows, toSql(), assertion); } public AssertNumRowsElement getAssertNumRowsElement() { diff --git a/fe/src/main/java/org/apache/doris/planner/AssertNumRowsNode.java b/fe/src/main/java/org/apache/doris/planner/AssertNumRowsNode.java index 033935c0aa..e0a6c5a7a4 100644 --- a/fe/src/main/java/org/apache/doris/planner/AssertNumRowsNode.java +++ b/fe/src/main/java/org/apache/doris/planner/AssertNumRowsNode.java @@ -33,11 +33,13 @@ public class AssertNumRowsNode extends PlanNode { private long desiredNumOfRows; private String subqueryString; + private AssertNumRowsElement.Assertion assertion; public AssertNumRowsNode(PlanNodeId id, PlanNode input, AssertNumRowsElement assertNumRowsElement) { super(id, "ASSERT NUMBER OF ROWS"); this.desiredNumOfRows = assertNumRowsElement.getDesiredNumOfRows(); this.subqueryString = assertNumRowsElement.getSubqueryString(); + this.assertion = assertNumRowsElement.getAssertion(); this.children.add(input); this.tupleIds = input.getTupleIds(); this.tblRefIds = input.getTblRefIds(); @@ -47,8 +49,8 @@ public class AssertNumRowsNode extends PlanNode { @Override protected String getNodeExplainString(String prefix, TExplainLevel detailLevel) { StringBuilder output = new StringBuilder() - .append(prefix + "assert number of rows: " ) - .append(desiredNumOfRows + "\n"); + .append(prefix + "assert number of rows: ") + .append(assertion).append(" ").append(desiredNumOfRows).append("\n"); return output.toString(); } @@ -58,5 +60,6 @@ public class AssertNumRowsNode extends 
PlanNode { msg.assert_num_rows_node = new TAssertNumRowsNode(); msg.assert_num_rows_node.setDesired_num_rows(desiredNumOfRows); msg.assert_num_rows_node.setSubquery_string(subqueryString); + msg.assert_num_rows_node.setAssertion(assertion.toThrift()); } } diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift index 5f04435524..1a5f3e19db 100644 --- a/gensrc/thrift/PlanNodes.thrift +++ b/gensrc/thrift/PlanNodes.thrift @@ -620,9 +620,19 @@ struct TBackendResourceProfile { 4: optional i64 max_row_buffer_size = 4194304 //TODO chenhao } +enum TAssertion { + EQ, // val1 == val2 + NE, // val1 != val2 + LT, // val1 < val2 + LE, // val1 <= val2 + GT, // val1 > val2 + GE // val1 >= val2 +} + struct TAssertNumRowsNode { 1: optional i64 desired_num_rows; 2: optional string subquery_string; + 3: optional TAssertion assertion; } // This is essentially a union of all messages corresponding to subclasses