diff --git a/be/src/exprs/arithmetic_expr.cpp b/be/src/exprs/arithmetic_expr.cpp index dc4399b0f0..23668805a6 100644 --- a/be/src/exprs/arithmetic_expr.cpp +++ b/be/src/exprs/arithmetic_expr.cpp @@ -21,6 +21,10 @@ namespace doris { +std::set ArithmeticExpr::_s_valid_fn_names = { + "add", "subtract", "multiply", "divide", "int_divide", + "mod", "bitand", "bitor", "bitxor", "bitnot"}; + Expr* ArithmeticExpr::from_thrift(const TExprNode& node) { switch (node.opcode) { case TExprOpcode::ADD: @@ -48,6 +52,31 @@ Expr* ArithmeticExpr::from_thrift(const TExprNode& node) { return nullptr; } +Expr* ArithmeticExpr::from_fn_name(const TExprNode& node) { + std::string fn_name = node.fn.name.function_name; + if (fn_name == "add") { + return new AddExpr(node); + } else if (fn_name == "subtract") { + return new SubExpr(node); + } else if (fn_name == "multiply") { + return new MulExpr(node); + } else if (fn_name == "divide" || fn_name == "int_divide") { + return new DivExpr(node); + } else if (fn_name == "mod") { + return new ModExpr(node); + } else if (fn_name == "bitand") { + return new BitAndExpr(node); + } else if (fn_name == "bitor") { + return new BitOrExpr(node); + } else if (fn_name == "bitxor") { + return new BitXorExpr(node); + } else if (fn_name == "bitnot") { + return new BitNotExpr(node); + } + + return nullptr; +} + #define BINARY_OP_CHECK_ZERO_FN(TYPE, CLASS, FN, OP) \ TYPE CLASS::FN(ExprContext* context, TupleRow* row) { \ TYPE v1 = _children[0]->FN(context, row); \ @@ -159,4 +188,41 @@ BINARY_BIT_FNS(BitXorExpr, ^) BITNOT_OP_FN(LargeIntVal, get_large_int_val) BITNOT_FNS() + +#define DECIMAL_ARITHMETIC_OP(EXPR_NAME, OP) \ + DecimalV2Val EXPR_NAME::get_decimalv2_val(ExprContext* context, TupleRow* row) { \ + DecimalV2Val v1 = _children[0]->get_decimalv2_val(context, row); \ + DecimalV2Val v2 = _children[1]->get_decimalv2_val(context, row); \ + if (v1.is_null || v2.is_null) { \ + return DecimalV2Val::null(); \ + } \ + DecimalV2Value iv1 = DecimalV2Value::from_decimal_val(v1); \ + DecimalV2Value iv2 = DecimalV2Value::from_decimal_val(v2); \ + DecimalV2Value ir = iv1 OP iv2; \ + DecimalV2Val result; \ + ir.to_decimal_val(&result); \ + return result; \ + } + +#define DECIMAL_ARITHMETIC_OP_DIVIDE(EXPR_NAME, OP) \ + DecimalV2Val EXPR_NAME::get_decimalv2_val(ExprContext* context, TupleRow* row) { \ + DecimalV2Val v1 = _children[0]->get_decimalv2_val(context, row); \ + DecimalV2Val v2 = _children[1]->get_decimalv2_val(context, row); \ + if (v1.is_null || v2.is_null || v2.value() == 0) { \ + return DecimalV2Val::null(); \ + } \ + DecimalV2Value iv1 = DecimalV2Value::from_decimal_val(v1); \ + DecimalV2Value iv2 = DecimalV2Value::from_decimal_val(v2); \ + DecimalV2Value ir = iv1 OP iv2; \ + DecimalV2Val result; \ + ir.to_decimal_val(&result); \ + return result; \ + } + +DECIMAL_ARITHMETIC_OP(AddExpr, +); +DECIMAL_ARITHMETIC_OP(SubExpr, -); +DECIMAL_ARITHMETIC_OP(MulExpr, *); +DECIMAL_ARITHMETIC_OP_DIVIDE(DivExpr, /); +DECIMAL_ARITHMETIC_OP_DIVIDE(ModExpr, %); + } // namespace doris diff --git a/be/src/exprs/arithmetic_expr.h b/be/src/exprs/arithmetic_expr.h index 5fd56a9cd7..4062847479 100644 --- a/be/src/exprs/arithmetic_expr.h +++ b/be/src/exprs/arithmetic_expr.h @@ -18,6 +18,8 @@ #ifndef DORIS_BE_SRC_EXPRS_ARITHMETIC_EXPR_H #define DORIS_BE_SRC_EXPRS_ARITHMETIC_EXPR_H +#include + #include "common/object_pool.h" #include "exprs/expr.h" @@ -25,7 +27,9 @@ namespace doris { class ArithmeticExpr : public Expr { public: + static bool is_valid(std::string fn_name) { return _s_valid_fn_names.count(fn_name); } static Expr* from_thrift(const TExprNode& node); + static Expr* from_fn_name(const TExprNode& node); protected: enum BinaryOpType { @@ -42,6 +46,8 @@ protected: ArithmeticExpr(const TExprNode& node) : Expr(node) {} virtual ~ArithmeticExpr() {} + + static std::set _s_valid_fn_names; }; class AddExpr : public ArithmeticExpr { @@ -56,6 +62,7 @@ public: virtual LargeIntVal get_large_int_val(ExprContext* context, TupleRow*) override; virtual FloatVal get_float_val(ExprContext* context, TupleRow*) override; virtual DoubleVal get_double_val(ExprContext* context, TupleRow*) override; + virtual DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow*) override; }; class SubExpr : public ArithmeticExpr { @@ -70,6 +77,7 @@ public: virtual LargeIntVal get_large_int_val(ExprContext* context, TupleRow*) override; virtual FloatVal get_float_val(ExprContext* context, TupleRow*) override; virtual DoubleVal get_double_val(ExprContext* context, TupleRow*) override; + virtual DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow*) override; }; class MulExpr : public ArithmeticExpr { @@ -84,6 +92,7 @@ public: virtual LargeIntVal get_large_int_val(ExprContext* context, TupleRow*) override; virtual FloatVal get_float_val(ExprContext* context, TupleRow*) override; virtual DoubleVal get_double_val(ExprContext* context, TupleRow*) override; + virtual DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow*) override; }; class DivExpr : public ArithmeticExpr { @@ -98,6 +107,7 @@ public: virtual LargeIntVal get_large_int_val(ExprContext* context, TupleRow*) override; virtual FloatVal get_float_val(ExprContext* context, TupleRow*) override; virtual DoubleVal get_double_val(ExprContext* context, TupleRow*) override; + virtual DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow*) override; }; class ModExpr : public ArithmeticExpr { @@ -112,6 +122,7 @@ public: virtual LargeIntVal get_large_int_val(ExprContext* context, TupleRow*) override; virtual FloatVal get_float_val(ExprContext* context, TupleRow*) override; virtual DoubleVal get_double_val(ExprContext* context, TupleRow*) override; + virtual DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow*) override; }; class BitAndExpr : public ArithmeticExpr { diff --git a/be/src/exprs/expr.cpp b/be/src/exprs/expr.cpp index 73f3775247..7429068959 100644 --- a/be/src/exprs/expr.cpp +++ b/be/src/exprs/expr.cpp @@ -360,6 +360,8 @@ Status Expr::create_expr(ObjectPool* pool, const TExprNode& texpr_node, Expr** e *expr = pool->add(new CoalesceExpr(texpr_node)); } else if (texpr_node.fn.binary_type == TFunctionBinaryType::RPC) { *expr = pool->add(new RPCFnCall(texpr_node)); + } else if (ArithmeticExpr::is_valid(texpr_node.fn.name.function_name)) { + *expr = pool->add(ArithmeticExpr::from_fn_name(texpr_node)); } else { *expr = pool->add(new ScalarFnCall(texpr_node)); } diff --git a/be/src/vec/functions/minus.cpp b/be/src/vec/functions/minus.cpp index e215a52f44..52b86a1085 100644 --- a/be/src/vec/functions/minus.cpp +++ b/be/src/vec/functions/minus.cpp @@ -34,6 +34,7 @@ struct MinusImpl { return static_cast(a) - b; } + template static inline DecimalV2Value apply(DecimalV2Value a, DecimalV2Value b) { return a - b; } diff --git a/be/src/vec/functions/modulo.cpp b/be/src/vec/functions/modulo.cpp index 0e8bf49244..28b0ec768d 100644 --- a/be/src/vec/functions/modulo.cpp +++ b/be/src/vec/functions/modulo.cpp @@ -18,6 +18,7 @@ // https://github.com/ClickHouse/ClickHouse/blob/master/src/Functions/Modulo.cpp // and modified by Doris +#include "runtime/decimalv2_value.h" #ifdef __SSE2__ #define LIBDIVIDE_SSE2 1 #endif @@ -47,6 +48,7 @@ struct ModuloImpl { } } + template static inline DecimalV2Value apply(DecimalV2Value a, DecimalV2Value b, NullMap& null_map, size_t index) { null_map[index] = b == DecimalV2Value(0); @@ -72,6 +74,7 @@ struct PModuloImpl { } } + template static inline DecimalV2Value apply(DecimalV2Value a, DecimalV2Value b, NullMap& null_map, size_t index) { null_map[index] = b == DecimalV2Value(0); diff --git a/be/src/vec/functions/multiply.cpp b/be/src/vec/functions/multiply.cpp index 73cc85443a..b2840e42e5 100644 --- a/be/src/vec/functions/multiply.cpp +++ b/be/src/vec/functions/multiply.cpp @@ -34,6 +34,7 @@ struct MultiplyImpl { return static_cast(a) * b; } + template static inline DecimalV2Value apply(DecimalV2Value a, DecimalV2Value b) { return a * b; } diff --git a/be/src/vec/functions/plus.cpp b/be/src/vec/functions/plus.cpp index 9f64070aa2..0da30dfa20 100644 --- a/be/src/vec/functions/plus.cpp +++ b/be/src/vec/functions/plus.cpp @@ -35,6 +35,7 @@ struct PlusImpl { return static_cast(a) + b; } + template static inline DecimalV2Value apply(DecimalV2Value a, DecimalV2Value b) { return a + b; } diff --git a/fe/fe-core/src/main/cup/sql_parser.cup b/fe/fe-core/src/main/cup/sql_parser.cup index 7c19093238..da8d8156a1 100644 --- a/fe/fe-core/src/main/cup/sql_parser.cup +++ b/fe/fe-core/src/main/cup/sql_parser.cup @@ -4735,6 +4735,8 @@ expr ::= function_call_expr ::= function_name:fn_name LPAREN RPAREN {: RESULT = new FunctionCallExpr(fn_name, new ArrayList()); :} + | KW_ADD LPAREN function_params:params RPAREN + {: RESULT = new FunctionCallExpr("add", params); :} | function_name:fn_name LPAREN function_params:params RPAREN {: if ("grouping".equalsIgnoreCase(fn_name.getFunction())) {