/** * Copyright (c) 2021 OceanBase * OceanBase CE is licensed under Mulan PubL v2. * You can use this software according to the terms and conditions of the Mulan PubL v2. * You may obtain a copy of Mulan PubL v2 at: * http://license.coscl.org.cn/MulanPubL-2.0 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. * See the Mulan PubL v2 for more details. */ #include #include "sql/test_sql_utils.h" #include "lib/utility/ob_test_util.h" #include "sql/resolver/expr/ob_raw_expr_resolver_impl.h" #include "sql/resolver/expr/ob_raw_expr_util.h" #include "sql/resolver/expr/ob_raw_expr_print_visitor.h" #include "sql/ob_sql_init.h" #include "lib/json/ob_json_print_utils.h" #include // files needed by LLVM #ifdef ID #undef ID #endif #include "llvm/DerivedTypes.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/ExecutionEngine/JIT.h" #include "llvm/LLVMContext.h" #include "llvm/Module.h" #include "llvm/PassManager.h" #include "llvm/Analysis/Verifier.h" #include "llvm/Analysis/Passes.h" #include "llvm/Target/TargetData.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Support/IRBuilder.h" #include "llvm/Support/TargetSelect.h" #include #include #include #include using namespace llvm; using namespace oceanbase::common; using namespace oceanbase::sql; static FunctionPassManager *TheFPM; static Module *TheModule; ////static std::map NamedValues; static IRBuilder<> Builder(getGlobalContext()); static ExecutionEngine *TheExecutionEngine; static int64_t get_usec() { struct timeval time_val; gettimeofday(&time_val, NULL); return time_val.tv_sec*1000000 + time_val.tv_usec; } class ExprAST; class TestLLVM: public ::testing::Test { public: TestLLVM(); virtual ~TestLLVM(); virtual void SetUp(); virtual void TearDown(); private: DISALLOW_COPY_AND_ASSIGN(TestLLVM); protected: // function members int resolve(const char* expr, const char *&json_expr); int convert(ObRawExpr *raw_expr, ExprAST *&ast); }; TestLLVM::TestLLVM() { } TestLLVM::~TestLLVM() { } void TestLLVM::SetUp() { } void TestLLVM::TearDown() { } /// ExprAST - Base class for all expression nodes. class ExprAST { public: virtual ~ExprAST() {} virtual Value *Codegen() = 0; }; /// NumberExprAST - Expression class for numeric literals like "1.0". class NumberExprAST : public ExprAST { double Val; public: NumberExprAST(double val) : Val(val) {} virtual Value *Codegen(); }; Value *NumberExprAST::Codegen() { return ConstantFP::get(getGlobalContext(), APFloat(Val)); } ExprAST *Error(const char *Str) { fprintf(stderr, "Error: %s\n", Str);return 0;} Value *ErrorV(const char *Str) { Error(Str); return 0; } ///// VariableExprAST - Expression class for referencing a variable, like "a". //class VariableExprAST : public ExprAST { // std::string Name; //public: // VariableExprAST(const std::string &name) : Name(name) {} // virtual Value *Codegen(); //}; /// BinaryExprAST - Expression class for a binary operator. class BinaryExprAST : public ExprAST { char Op; // ObItemType ExprAST *LHS, *RHS; public: BinaryExprAST(char op, ExprAST *lhs, ExprAST *rhs) : Op(op), LHS(lhs), RHS(rhs) {} void add_child_ast(ExprAST *child); virtual Value *Codegen(); }; Value *BinaryExprAST::Codegen() { Value *L = LHS->Codegen(); Value *R = RHS->Codegen(); if (L == 0 || R == 0) return 0; // printf("binary op %d\n", Op); switch (Op) { case T_OP_ADD: return Builder.CreateFAdd(L, R, "addtmp"); case T_OP_MINUS: return Builder.CreateFSub(L, R, "subtmp"); case T_OP_MUL: return Builder.CreateFMul(L, R, "multmp"); case T_OP_LT: L = Builder.CreateFCmpULT(L, R, "cmptmp"); // Convert bool 0/1 to double 0.0 or 1.0 return Builder.CreateUIToFP(L, Type::getDoubleTy(getGlobalContext()), "booltmp"); default: return ErrorV("invalid binary operator"); } } void BinaryExprAST::add_child_ast(ExprAST *child) { if (LHS) { RHS = child; } else { LHS = child; } } /// PrototypeAST - This class represents the "prototype" for a function, /// which captures its name, and its argument names (thus implicitly the number /// of arguments the function takes). class PrototypeAST { std::string Name; std::vector Args; std::map NamedValues; public: PrototypeAST(const std::string &name, const std::vector &args) : Name(name), Args(args) {} Function *Codegen(); }; Function *PrototypeAST::Codegen() { // Make the function type: double(double,double) etc. std::vector Doubles(Args.size(), Type::getDoubleTy(getGlobalContext())); FunctionType *FT = FunctionType::get(Type::getDoubleTy(getGlobalContext()), Doubles, false); Function *F = Function::Create(FT, Function::ExternalLinkage, Name, TheModule); // If F conflicted, there was already something named 'Name'. If it has a // body, don't allow redefinition or reextern. if (F->getName() != Name) { // Delete the one we just made and get the existing one. F->eraseFromParent(); F = TheModule->getFunction(Name); // If F already has a body, reject this. if (!F->empty()) { printf("redefinition of function"); return 0; } // If F took a different number of args, reject. if (F->arg_size() != Args.size()) { printf("redefinition of function with different # args"); return 0; } } // Set names for all arguments. unsigned Idx = 0; for (Function::arg_iterator AI = F->arg_begin(); Idx != Args.size(); ++AI, ++Idx) { AI->setName(Args[Idx]); // Add arguments to variable symbol table. NamedValues[Args[Idx]] = AI; } return F; } /// FunctionAST - This class represents a function definition itself. class FunctionAST { PrototypeAST *Proto; ExprAST *Body; public: FunctionAST(PrototypeAST *proto, ExprAST *body) : Proto(proto), Body(body) {} Function *Codegen(); }; Function *FunctionAST::Codegen() { // NamedValues.clear(); Function *TheFunction = Proto->Codegen(); if (TheFunction == 0) return 0; // Create a new basic block to start insertion into. BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction); Builder.SetInsertPoint(BB); if (Value *RetVal = Body->Codegen()) { // Finish off the function. Builder.CreateRet(RetVal); // Validate the generated code, checking for consistency. verifyFunction(*TheFunction); // Optimize the function. TheFPM->run(*TheFunction); return TheFunction; } // Error reading body, remove function. TheFunction->eraseFromParent(); return 0; } int TestLLVM::resolve(const char* expr, const char *&json_expr) { int ret = OB_SUCCESS; ObArray expr_store; ObArray columns; ObArray sys_vars; ObArray sub_query_info; const char* expr_str = expr; ObArenaAllocator allocator(ObModIds::TEST); ObTimeZoneInfo tz_info; ObNameCaseMode case_mode = OB_NAME_CASE_INVALID; ObRawExprFactory expr_factory(allocator); ObExprResolveContext ctx(expr_factory, &tz_info, case_mode); ctx.connection_charset_ = ObCharset::get_default_charset(); ctx.dest_collation_ = ObCharset::get_default_collation(ctx.connection_charset_); ctx.is_extract_param_type_ = false; ObRawExpr *raw_expr = NULL; ObArray aggr_exprs; ObArray win_exprs; ObArray udf_info; if (OB_FAIL(ObRawExprUtils::make_raw_expr_from_str(expr_str, strlen(expr_str), ctx, raw_expr, columns, sys_vars, &sub_query_info, aggr_exprs, win_exprs, udf_info))) { printf("error"); } _OB_LOG(DEBUG, "================================================================"); _OB_LOG(DEBUG, "%s", expr); _OB_LOG(DEBUG, "%s", CSJ(raw_expr)); if (OB_FAIL(raw_expr->extract_info())) { printf("error"); } //OK(raw_expr->deduce_type()); json_expr = CSJ(raw_expr); ExprAST *ast = NULL; if (OB_FAIL(convert(raw_expr, ast))) { printf("fail to convert"); } else { InitializeNativeTarget(); LLVMContext &Context = getGlobalContext(); // Make the module, which holds all the code. TheModule = new Module("my cool jit", Context); // Create the JIT. This takes ownership of the module. std::string ErrStr; TheExecutionEngine = EngineBuilder(TheModule).setErrorStr(&ErrStr).create(); if (!TheExecutionEngine) { printf("Could not create ExecutionEngine\n"); std::cout << "ErrStr : " << ErrStr << std::endl; OB_ASSERT(0); } else { FunctionPassManager OurFPM(TheModule); // Set up the optimizer pipeline. Start with registering info about how the // target lays out data structures. OurFPM.add(new TargetData(*TheExecutionEngine->getTargetData())); // Provide basic AliasAnalysis support for GVN. OurFPM.add(createBasicAliasAnalysisPass()); // Do simple "peephole" optimizations and bit-twiddling optzns. OurFPM.add(createInsstructionCombiningPass()); // Reassociate expressions. OurFPM.add(createReassociatePass()); // Eliminate Common SubExpressions. OurFPM.add(createGVNPass()); // Simplify the control flow graph (deleting unreachable blocks, etc). OurFPM.add(createCFGSimplificationPass()); OurFPM.doInitialization(); // Set the global so the code gen can use this. TheFPM = &OurFPM; PrototypeAST *Proto = new PrototypeAST("", std::vector()); FunctionAST *func = new FunctionAST(Proto, ast); if (!func) { } else if (Function *LF = func->Codegen()) { // JIT the function, returning a function pointer. int64_t i, j, m, n; i = get_usec(); void *FPtr = TheExecutionEngine->getPointerToFunction(LF); j = get_usec(); printf("JIT in %ld us\n", j - i); // Cast it to the right type (takes no arguments, returns a double) so we // can call it as a native function. double (*FP)() = (double (*)())(intptr_t)FPtr; m = get_usec(); for (int64_t k = 0; k < 10000000; k++) { FP(); } n = get_usec(); printf("Evaluated to %f\n in %ld us\n", FP(), n-m); } } } return ret; } int TestLLVM::convert(ObRawExpr *raw_expr, ExprAST *&ast) { int ret = OB_SUCCESS; switch (raw_expr->get_expr_class()) { case ObRawExpr::EXPR_CONST: { ObConstRawExpr *const_expr = static_cast (raw_expr); ast = new NumberExprAST(static_cast (const_expr->get_value().get_int())); } break; case ObRawExpr::EXPR_OPERATOR: { ast = new BinaryExprAST(raw_expr->get_expr_type(), NULL, NULL); } break; default: OB_ASSERT(0); break; } if (raw_expr->get_expr_class() == ObRawExpr::EXPR_CONST) { // do nothing } else { ObOpRawExpr *op_expr = static_cast (raw_expr); BinaryExprAST *binary_ast = static_cast (ast); for (int64_t i = 0; OB_SUCC(ret) && i < op_expr->get_param_count(); i++) { ExprAST *new_ast = NULL; if (OB_FAIL(convert(op_expr->get_param_exprs().at(i), new_ast))) { printf("fail to convert"); } else { binary_ast->add_child_ast(new_ast); } } } return ret; } TEST_F(TestLLVM, all) { const char* json_expr = NULL; int ret = OB_SUCCESS; if (OB_SUCCESS != (ret = resolve("1+2", json_expr))) { printf("fail to resolve\n"); } else if (OB_SUCCESS != (ret = resolve("2-1", json_expr))) { printf("fail to resolve\n"); } else if (OB_SUCCESS != (ret = resolve("1345435+ 1232132", json_expr))) { printf("fail to resolve\n"); } else if (OB_SUCCESS != (ret = resolve("6/3", json_expr))) { } // static const char* test_file = "./expr/test_raw_expr_resolver.test"; // static const char* tmp_file = "./expr/test_raw_expr_resolver.tmp"; // static const char* result_file = "./expr/test_raw_expr_resolver.result"; // // std::ifstream if_tests(test_file); // ASSERT_TRUE(if_tests.is_open()); // std::string line; // const char* json_expr = NULL; // std::ofstream of_result(tmp_file); // ASSERT_TRUE(of_result.is_open()); // int64_t case_id = 0; // while (std::getline(if_tests, line)) { // of_result << '[' << case_id++ << "] " << line << std::endl; // resolve(line.c_str(), json_expr); // of_result << json_expr << std::endl; // } // of_result.close(); // // verify results // fprintf(stderr, "If tests failed, use `diff %s %s' to see the differences. \n", result_file, tmp_file); // std::ifstream if_result(tmp_file); // ASSERT_TRUE(if_result.is_open()); // std::istream_iterator it_result(if_result); // std::ifstream if_expected(result_file); // ASSERT_TRUE(if_expected.is_open()); // std::istream_iterator it_expected(if_expected); // ASSERT_TRUE(std::equal(it_result, std::istream_iterator(), it_expected)); // std::remove(tmp_file); } int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc,argv); init_sql_factories(); return RUN_ALL_TESTS(); }