456 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			456 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /**
 | |
|  * Copyright (c) 2021 OceanBase
 | |
|  * OceanBase CE is licensed under Mulan PubL v2.
 | |
|  * You can use this software according to the terms and conditions of the Mulan PubL v2.
 | |
|  * You may obtain a copy of Mulan PubL v2 at:
 | |
|  *          http://license.coscl.org.cn/MulanPubL-2.0
 | |
|  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 | |
|  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 | |
|  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 | |
|  * See the Mulan PubL v2 for more details.
 | |
|  */
 | |
| 
 | |
| #include <gtest/gtest.h>
 | |
| #include "sql/test_sql_utils.h"
 | |
| #include "lib/utility/ob_test_util.h"
 | |
| #include "sql/resolver/expr/ob_raw_expr_resolver_impl.h"
 | |
| #include "sql/resolver/expr/ob_raw_expr_util.h"
 | |
| #include "sql/resolver/expr/ob_raw_expr_print_visitor.h"
 | |
| #include "sql/ob_sql_init.h"
 | |
| #include "lib/json/ob_json_print_utils.h"
 | |
| #include <fstream>
 | |
| 
 | |
| // files needed by LLVM
 | |
| #ifdef ID
 | |
| #undef ID
 | |
| #endif
 | |
| #include "llvm/DerivedTypes.h"
 | |
| #include "llvm/ExecutionEngine/ExecutionEngine.h"
 | |
| #include "llvm/ExecutionEngine/JIT.h"
 | |
| #include "llvm/LLVMContext.h"
 | |
| #include "llvm/Module.h"
 | |
| #include "llvm/PassManager.h"
 | |
| #include "llvm/Analysis/Verifier.h"
 | |
| #include "llvm/Analysis/Passes.h"
 | |
| #include "llvm/Target/TargetData.h"
 | |
| #include "llvm/Transforms/Scalar.h"
 | |
| #include "llvm/Support/IRBuilder.h"
 | |
| #include "llvm/Support/TargetSelect.h"
 | |
| #include <cstdio>
 | |
| #include <string>
 | |
| #include <map>
 | |
| #include <vector>
 | |
| 
 | |
| using namespace llvm;
 | |
| using namespace oceanbase::common;
 | |
| using namespace oceanbase::sql;
 | |
| 
 | |
| static FunctionPassManager *TheFPM;
 | |
| static Module *TheModule;
 | |
| ////static std::map<std::string, Value*> NamedValues;
 | |
| static IRBuilder<> Builder(getGlobalContext());
 | |
| static ExecutionEngine *TheExecutionEngine;
 | |
| 
 | |
| static int64_t get_usec()
 | |
| {
 | |
|   struct timeval time_val;
 | |
|   gettimeofday(&time_val, NULL);
 | |
|   return time_val.tv_sec*1000000 + time_val.tv_usec;
 | |
| }
 | |
| 
 | |
| class ExprAST;
 | |
| class TestLLVM: public ::testing::Test
 | |
| {
 | |
| public:
 | |
|   TestLLVM();
 | |
|   virtual ~TestLLVM();
 | |
|   virtual void SetUp();
 | |
|   virtual void TearDown();
 | |
| private:
 | |
| 
 | |
|   DISALLOW_COPY_AND_ASSIGN(TestLLVM);
 | |
| protected:
 | |
|   // function members
 | |
|   int resolve(const char* expr, const char *&json_expr);
 | |
|   int convert(ObRawExpr *raw_expr, ExprAST *&ast);
 | |
| };
 | |
| 
 | |
| 
 | |
| TestLLVM::TestLLVM()
 | |
| {
 | |
| }
 | |
| 
 | |
| TestLLVM::~TestLLVM()
 | |
| {
 | |
| }
 | |
| 
 | |
| void TestLLVM::SetUp()
 | |
| {
 | |
| }
 | |
| 
 | |
| void TestLLVM::TearDown()
 | |
| {
 | |
| }
 | |
| 
 | |
| /// ExprAST - Base class for all expression nodes.
 | |
| class ExprAST {
 | |
| public:
 | |
|   virtual ~ExprAST() {}
 | |
|   virtual Value *Codegen() = 0;
 | |
| };
 | |
| 
 | |
| /// NumberExprAST - Expression class for numeric literals like "1.0".
 | |
| class NumberExprAST : public ExprAST {
 | |
|   double Val;
 | |
| public:
 | |
|   NumberExprAST(double val) : Val(val) {}
 | |
|   virtual Value *Codegen();
 | |
| };
 | |
| 
 | |
| Value *NumberExprAST::Codegen() {
 | |
|   return ConstantFP::get(getGlobalContext(), APFloat(Val));
 | |
| }
 | |
| 
 | |
| ExprAST *Error(const char *Str) { fprintf(stderr, "Error: %s\n", Str);return 0;}
 | |
| Value *ErrorV(const char *Str) { Error(Str); return 0; }
 | |
| 
 | |
| ///// VariableExprAST - Expression class for referencing a variable, like "a".
 | |
| //class VariableExprAST : public ExprAST {
 | |
| //  std::string Name;
 | |
| //public:
 | |
| //  VariableExprAST(const std::string &name) : Name(name) {}
 | |
| //  virtual Value *Codegen();
 | |
| //};
 | |
| 
 | |
| 
 | |
| /// BinaryExprAST - Expression class for a binary operator.
 | |
| class BinaryExprAST : public ExprAST {
 | |
|   char Op;    // ObItemType
 | |
|   ExprAST *LHS, *RHS;
 | |
| public:
 | |
|   BinaryExprAST(char op, ExprAST *lhs, ExprAST *rhs)
 | |
|     : Op(op), LHS(lhs), RHS(rhs) {}
 | |
|   void add_child_ast(ExprAST *child);
 | |
|   virtual Value *Codegen();
 | |
| };
 | |
| 
 | |
| Value *BinaryExprAST::Codegen() {
 | |
|   Value *L = LHS->Codegen();
 | |
|   Value *R = RHS->Codegen();
 | |
|   if (L == 0 || R == 0) return 0;
 | |
| //  printf("binary op %d\n", Op);
 | |
|   switch (Op) {
 | |
|   case T_OP_ADD: return Builder.CreateFAdd(L, R, "addtmp");
 | |
|   case T_OP_MINUS: return Builder.CreateFSub(L, R, "subtmp");
 | |
|   case T_OP_MUL: return Builder.CreateFMul(L, R, "multmp");
 | |
|   case T_OP_LT:
 | |
|     L = Builder.CreateFCmpULT(L, R, "cmptmp");
 | |
|     // Convert bool 0/1 to double 0.0 or 1.0
 | |
|     return Builder.CreateUIToFP(L, Type::getDoubleTy(getGlobalContext()),
 | |
|                                 "booltmp");
 | |
|   default: return ErrorV("invalid binary operator");
 | |
|   }
 | |
| }
 | |
| 
 | |
| void BinaryExprAST::add_child_ast(ExprAST *child)
 | |
| {
 | |
|   if (LHS) {
 | |
|     RHS = child;
 | |
|   } else {
 | |
|     LHS = child;
 | |
|   }
 | |
| }
 | |
| 
 | |
| 
 | |
| /// PrototypeAST - This class represents the "prototype" for a function,
 | |
| /// which captures its name, and its argument names (thus implicitly the number
 | |
| /// of arguments the function takes).
 | |
| class PrototypeAST {
 | |
|   std::string Name;
 | |
|   std::vector<std::string> Args;
 | |
|   std::map<std::string, Value*> NamedValues;
 | |
| public:
 | |
|   PrototypeAST(const std::string &name, const std::vector<std::string> &args)
 | |
|     : Name(name), Args(args) {}
 | |
| 
 | |
|   Function *Codegen();
 | |
| };
 | |
| 
 | |
| Function *PrototypeAST::Codegen() {
 | |
|   // Make the function type:  double(double,double) etc.
 | |
|   std::vector<Type*> Doubles(Args.size(),
 | |
|                              Type::getDoubleTy(getGlobalContext()));
 | |
|   FunctionType *FT = FunctionType::get(Type::getDoubleTy(getGlobalContext()),
 | |
|                                        Doubles, false);
 | |
| 
 | |
|   Function *F = Function::Create(FT, Function::ExternalLinkage, Name, TheModule);
 | |
| 
 | |
|   // If F conflicted, there was already something named 'Name'.  If it has a
 | |
|   // body, don't allow redefinition or reextern.
 | |
|   if (F->getName() != Name) {
 | |
|     // Delete the one we just made and get the existing one.
 | |
|     F->eraseFromParent();
 | |
|     F = TheModule->getFunction(Name);
 | |
| 
 | |
|     // If F already has a body, reject this.
 | |
|     if (!F->empty()) {
 | |
|       printf("redefinition of function");
 | |
|       return 0;
 | |
|     }
 | |
| 
 | |
|     // If F took a different number of args, reject.
 | |
|     if (F->arg_size() != Args.size()) {
 | |
|       printf("redefinition of function with different # args");
 | |
|       return 0;
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   // Set names for all arguments.
 | |
|   unsigned Idx = 0;
 | |
|   for (Function::arg_iterator AI = F->arg_begin(); Idx != Args.size();
 | |
|        ++AI, ++Idx) {
 | |
|     AI->setName(Args[Idx]);
 | |
| 
 | |
|     // Add arguments to variable symbol table.
 | |
|     NamedValues[Args[Idx]] = AI;
 | |
|   }
 | |
| 
 | |
|   return F;
 | |
| }
 | |
| 
 | |
| /// FunctionAST - This class represents a function definition itself.
 | |
| class FunctionAST {
 | |
|   PrototypeAST *Proto;
 | |
|   ExprAST *Body;
 | |
| public:
 | |
|   FunctionAST(PrototypeAST *proto, ExprAST *body)
 | |
|     : Proto(proto), Body(body) {}
 | |
| 
 | |
|   Function *Codegen();
 | |
| };
 | |
| 
 | |
| Function *FunctionAST::Codegen() {
 | |
| //  NamedValues.clear();
 | |
| 
 | |
|   Function *TheFunction = Proto->Codegen();
 | |
|   if (TheFunction == 0)
 | |
|     return 0;
 | |
| 
 | |
|   // Create a new basic block to start insertion into.
 | |
|   BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
 | |
|   Builder.SetInsertPoint(BB);
 | |
| 
 | |
|   if (Value *RetVal = Body->Codegen()) {
 | |
|     // Finish off the function.
 | |
|     Builder.CreateRet(RetVal);
 | |
| 
 | |
|     // Validate the generated code, checking for consistency.
 | |
|     verifyFunction(*TheFunction);
 | |
| 
 | |
|     // Optimize the function.
 | |
|     TheFPM->run(*TheFunction);
 | |
| 
 | |
|     return TheFunction;
 | |
|   }
 | |
| 
 | |
|   // Error reading body, remove function.
 | |
|   TheFunction->eraseFromParent();
 | |
|   return 0;
 | |
| }
 | |
| 
 | |
| int TestLLVM::resolve(const char* expr, const char *&json_expr)
 | |
| {
 | |
|   int ret = OB_SUCCESS;
 | |
|   ObArray<ObRawExpr*> expr_store;
 | |
|   ObArray<ObQualifiedName> columns;
 | |
|   ObArray<ObVarInfo> sys_vars;
 | |
|   ObArray<ObSubQueryInfo> sub_query_info;
 | |
|   const char* expr_str = expr;
 | |
|   ObArenaAllocator allocator(ObModIds::TEST);
 | |
|   ObTimeZoneInfo tz_info;
 | |
|   ObNameCaseMode case_mode = OB_NAME_CASE_INVALID;
 | |
|   ObRawExprFactory expr_factory(allocator);
 | |
|   ObExprResolveContext ctx(expr_factory, &tz_info, case_mode);
 | |
|   ctx.connection_charset_ = ObCharset::get_default_charset();
 | |
|   ctx.dest_collation_ = ObCharset::get_default_collation(ctx.connection_charset_);
 | |
|   ctx.is_extract_param_type_ = false;
 | |
|   ObRawExpr *raw_expr = NULL;
 | |
|   ObArray<ObAggFunRawExpr*> aggr_exprs;
 | |
|   ObArray<ObWinFunRawExpr*> win_exprs;
 | |
|   ObArray<ObUDFInfo> udf_info;
 | |
|   if (OB_FAIL(ObRawExprUtils::make_raw_expr_from_str(expr_str,
 | |
|                                                      strlen(expr_str),
 | |
|                                                      ctx,
 | |
|                                                      raw_expr,
 | |
|                                                      columns,
 | |
|                                                      sys_vars,
 | |
|                                                      &sub_query_info,
 | |
|                                                      aggr_exprs,
 | |
|                                                      win_exprs,
 | |
|                                                      udf_info)))
 | |
|   {
 | |
|     printf("error");
 | |
|   }
 | |
| 
 | |
|   _OB_LOG(DEBUG, "================================================================");
 | |
|   _OB_LOG(DEBUG, "%s", expr);
 | |
|   _OB_LOG(DEBUG, "%s", CSJ(raw_expr));
 | |
|   if (OB_FAIL(raw_expr->extract_info())) {
 | |
|     printf("error");
 | |
|   }
 | |
|   //OK(raw_expr->deduce_type());
 | |
|   json_expr = CSJ(raw_expr);
 | |
| 
 | |
| 
 | |
|   ExprAST *ast = NULL;
 | |
|   if (OB_FAIL(convert(raw_expr, ast))) {
 | |
|     printf("fail to convert");
 | |
|   } else {
 | |
|     InitializeNativeTarget();
 | |
|     LLVMContext &Context = getGlobalContext();
 | |
|     // Make the module, which holds all the code.
 | |
|     TheModule = new Module("my cool jit", Context);
 | |
| 
 | |
|     // Create the JIT.  This takes ownership of the module.
 | |
|     std::string ErrStr;
 | |
|     TheExecutionEngine = EngineBuilder(TheModule).setErrorStr(&ErrStr).create();
 | |
|     if (!TheExecutionEngine) {
 | |
|       printf("Could not create ExecutionEngine\n");
 | |
|       std::cout << "ErrStr : " << ErrStr << std::endl;
 | |
|       OB_ASSERT(0);
 | |
|     } else {
 | |
| 
 | |
| 
 | |
|       FunctionPassManager OurFPM(TheModule);
 | |
| 
 | |
|       // Set up the optimizer pipeline.  Start with registering info about how the
 | |
|       // target lays out data structures.
 | |
|       OurFPM.add(new TargetData(*TheExecutionEngine->getTargetData()));
 | |
|       // Provide basic AliasAnalysis support for GVN.
 | |
|       OurFPM.add(createBasicAliasAnalysisPass());
 | |
|       // Do simple "peephole" optimizations and bit-twiddling optzns.
 | |
|       OurFPM.add(createInsstructionCombiningPass());
 | |
|       // Reassociate expressions.
 | |
|       OurFPM.add(createReassociatePass());
 | |
|       // Eliminate Common SubExpressions.
 | |
|       OurFPM.add(createGVNPass());
 | |
|       // Simplify the control flow graph (deleting unreachable blocks, etc).
 | |
|       OurFPM.add(createCFGSimplificationPass());
 | |
| 
 | |
|       OurFPM.doInitialization();
 | |
| 
 | |
|       // Set the global so the code gen can use this.
 | |
|       TheFPM = &OurFPM;
 | |
| 
 | |
| 
 | |
|       PrototypeAST *Proto = new PrototypeAST("", std::vector<std::string>());
 | |
|       FunctionAST *func = new FunctionAST(Proto, ast);
 | |
| 
 | |
|       if (!func) {
 | |
|       } else if (Function *LF = func->Codegen()) {
 | |
|         // JIT the function, returning a function pointer.
 | |
|         int64_t i, j, m, n;
 | |
|         i = get_usec();
 | |
|         void *FPtr = TheExecutionEngine->getPointerToFunction(LF);
 | |
|         j = get_usec();
 | |
|         printf("JIT in %ld us\n", j - i);
 | |
|         // Cast it to the right type (takes no arguments, returns a double) so we
 | |
|         // can call it as a native function.
 | |
|         double (*FP)() = (double (*)())(intptr_t)FPtr;
 | |
|         m = get_usec();
 | |
|         for (int64_t k = 0; k < 10000000; k++) {
 | |
|           FP();
 | |
|         }
 | |
|         n = get_usec();
 | |
|         printf("Evaluated to %f\n in %ld us\n", FP(), n-m);
 | |
|       }
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   return ret;
 | |
| }
 | |
| 
 | |
| int TestLLVM::convert(ObRawExpr *raw_expr, ExprAST *&ast)
 | |
| {
 | |
|   int ret = OB_SUCCESS;
 | |
|   switch (raw_expr->get_expr_class()) {
 | |
|   case ObRawExpr::EXPR_CONST: {
 | |
|     ObConstRawExpr *const_expr = static_cast<ObConstRawExpr*> (raw_expr);
 | |
|     ast = new NumberExprAST(static_cast<double> (const_expr->get_value().get_int()));
 | |
|   }
 | |
|     break;
 | |
|   case ObRawExpr::EXPR_OPERATOR: {
 | |
|     ast = new BinaryExprAST(raw_expr->get_expr_type(), NULL, NULL);
 | |
|   }
 | |
|     break;
 | |
|   default:
 | |
|     OB_ASSERT(0);
 | |
|     break;
 | |
|   }
 | |
| 
 | |
|   if (raw_expr->get_expr_class() == ObRawExpr::EXPR_CONST) {
 | |
|     // do nothing
 | |
|   } else {
 | |
|     ObOpRawExpr *op_expr = static_cast<ObOpRawExpr *> (raw_expr);
 | |
|     BinaryExprAST *binary_ast = static_cast<BinaryExprAST *> (ast);
 | |
|     for (int64_t i = 0; OB_SUCC(ret) && i < op_expr->get_param_count(); i++) {
 | |
|       ExprAST *new_ast = NULL;
 | |
|       if (OB_FAIL(convert(op_expr->get_param_exprs().at(i), new_ast))) {
 | |
|         printf("fail to convert");
 | |
|       } else {
 | |
|         binary_ast->add_child_ast(new_ast);
 | |
|       }
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   return ret;
 | |
| }
 | |
| 
 | |
| TEST_F(TestLLVM, all)
 | |
| {
 | |
|   const char* json_expr = NULL;
 | |
|   int ret = OB_SUCCESS;
 | |
|   if (OB_SUCCESS != (ret = resolve("1+2", json_expr))) {
 | |
|     printf("fail to resolve\n");
 | |
|   } else if (OB_SUCCESS != (ret = resolve("2-1", json_expr))) {
 | |
|     printf("fail to resolve\n");
 | |
|   } else if (OB_SUCCESS != (ret = resolve("1345435+ 1232132", json_expr))) {
 | |
|     printf("fail to resolve\n");
 | |
|   } else if (OB_SUCCESS != (ret = resolve("6/3", json_expr))) {
 | |
|   }
 | |
| //  static const char* test_file = "./expr/test_raw_expr_resolver.test";
 | |
| //  static const char* tmp_file = "./expr/test_raw_expr_resolver.tmp";
 | |
| //  static const char* result_file = "./expr/test_raw_expr_resolver.result";
 | |
| //
 | |
| //  std::ifstream if_tests(test_file);
 | |
| //  ASSERT_TRUE(if_tests.is_open());
 | |
| //  std::string line;
 | |
| //  const char* json_expr = NULL;
 | |
| //  std::ofstream of_result(tmp_file);
 | |
| //  ASSERT_TRUE(of_result.is_open());
 | |
| //  int64_t case_id = 0;
 | |
| //  while (std::getline(if_tests, line)) {
 | |
| //    of_result << '[' << case_id++ << "] " << line << std::endl;
 | |
| //    resolve(line.c_str(), json_expr);
 | |
| //    of_result << json_expr << std::endl;
 | |
| //  }
 | |
| //  of_result.close();
 | |
| //  // verify results
 | |
| //  fprintf(stderr, "If tests failed, use `diff %s %s' to see the differences. \n", result_file, tmp_file);
 | |
| //  std::ifstream if_result(tmp_file);
 | |
| //  ASSERT_TRUE(if_result.is_open());
 | |
| //  std::istream_iterator<std::string> it_result(if_result);
 | |
| //  std::ifstream if_expected(result_file);
 | |
| //  ASSERT_TRUE(if_expected.is_open());
 | |
| //  std::istream_iterator<std::string> it_expected(if_expected);
 | |
| //  ASSERT_TRUE(std::equal(it_result, std::istream_iterator<std::string>(), it_expected));
 | |
| //  std::remove(tmp_file);
 | |
| }
 | |
| 
 | |
| int main(int argc, char **argv)
 | |
| {
 | |
|   ::testing::InitGoogleTest(&argc,argv);
 | |
|   init_sql_factories();
 | |
|   return RUN_ALL_TESTS();
 | |
| }
 | 
