456 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			456 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
/**
 | 
						|
 * Copyright (c) 2021 OceanBase
 | 
						|
 * OceanBase CE is licensed under Mulan PubL v2.
 | 
						|
 * You can use this software according to the terms and conditions of the Mulan PubL v2.
 | 
						|
 * You may obtain a copy of Mulan PubL v2 at:
 | 
						|
 *          http://license.coscl.org.cn/MulanPubL-2.0
 | 
						|
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 | 
						|
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 | 
						|
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 | 
						|
 * See the Mulan PubL v2 for more details.
 | 
						|
 */
 | 
						|
 | 
						|
#include <gtest/gtest.h>
 | 
						|
#include "sql/test_sql_utils.h"
 | 
						|
#include "lib/utility/ob_test_util.h"
 | 
						|
#include "sql/resolver/expr/ob_raw_expr_resolver_impl.h"
 | 
						|
#include "sql/resolver/expr/ob_raw_expr_util.h"
 | 
						|
#include "sql/resolver/expr/ob_raw_expr_print_visitor.h"
 | 
						|
#include "sql/ob_sql_init.h"
 | 
						|
#include "lib/json/ob_json_print_utils.h"
 | 
						|
#include <fstream>
 | 
						|
 | 
						|
// files needed by LLVM
 | 
						|
#ifdef ID
 | 
						|
#undef ID
 | 
						|
#endif
 | 
						|
#include "llvm/DerivedTypes.h"
 | 
						|
#include "llvm/ExecutionEngine/ExecutionEngine.h"
 | 
						|
#include "llvm/ExecutionEngine/JIT.h"
 | 
						|
#include "llvm/LLVMContext.h"
 | 
						|
#include "llvm/Module.h"
 | 
						|
#include "llvm/PassManager.h"
 | 
						|
#include "llvm/Analysis/Verifier.h"
 | 
						|
#include "llvm/Analysis/Passes.h"
 | 
						|
#include "llvm/Target/TargetData.h"
 | 
						|
#include "llvm/Transforms/Scalar.h"
 | 
						|
#include "llvm/Support/IRBuilder.h"
 | 
						|
#include "llvm/Support/TargetSelect.h"
 | 
						|
#include <cstdio>
 | 
						|
#include <string>
 | 
						|
#include <map>
 | 
						|
#include <vector>
 | 
						|
 | 
						|
using namespace llvm;
 | 
						|
using namespace oceanbase::common;
 | 
						|
using namespace oceanbase::sql;
 | 
						|
 | 
						|
static FunctionPassManager *TheFPM;
 | 
						|
static Module *TheModule;
 | 
						|
////static std::map<std::string, Value*> NamedValues;
 | 
						|
static IRBuilder<> Builder(getGlobalContext());
 | 
						|
static ExecutionEngine *TheExecutionEngine;
 | 
						|
 | 
						|
/// Returns the current wall-clock time in microseconds since the Unix epoch.
///
/// Only used to time JIT compilation and evaluation in resolve(). tv_sec is
/// widened to int64_t before scaling so the multiplication cannot overflow on
/// platforms where time_t is 32 bits.
static int64_t get_usec()
{
  struct timeval time_val;
  gettimeofday(&time_val, NULL);
  return static_cast<int64_t>(time_val.tv_sec) * 1000000 + time_val.tv_usec;
}
 | 
						|
 | 
						|
class ExprAST;
 | 
						|
class TestLLVM: public ::testing::Test
 | 
						|
{
 | 
						|
public:
 | 
						|
  TestLLVM();
 | 
						|
  virtual ~TestLLVM();
 | 
						|
  virtual void SetUp();
 | 
						|
  virtual void TearDown();
 | 
						|
private:
 | 
						|
 | 
						|
  DISALLOW_COPY_AND_ASSIGN(TestLLVM);
 | 
						|
protected:
 | 
						|
  // function members
 | 
						|
  int resolve(const char* expr, const char *&json_expr);
 | 
						|
  int convert(ObRawExpr *raw_expr, ExprAST *&ast);
 | 
						|
};
 | 
						|
 | 
						|
 | 
						|
TestLLVM::TestLLVM()
 | 
						|
{
 | 
						|
}
 | 
						|
 | 
						|
TestLLVM::~TestLLVM()
 | 
						|
{
 | 
						|
}
 | 
						|
 | 
						|
void TestLLVM::SetUp()
 | 
						|
{
 | 
						|
}
 | 
						|
 | 
						|
void TestLLVM::TearDown()
 | 
						|
{
 | 
						|
}
 | 
						|
 | 
						|
/// ExprAST - Base class for all expression nodes.
 | 
						|
class ExprAST {
 | 
						|
public:
 | 
						|
  virtual ~ExprAST() {}
 | 
						|
  virtual Value *Codegen() = 0;
 | 
						|
};
 | 
						|
 | 
						|
/// NumberExprAST - Expression class for numeric literals like "1.0".
 | 
						|
class NumberExprAST : public ExprAST {
 | 
						|
  double Val;
 | 
						|
public:
 | 
						|
  NumberExprAST(double val) : Val(val) {}
 | 
						|
  virtual Value *Codegen();
 | 
						|
};
 | 
						|
 | 
						|
Value *NumberExprAST::Codegen() {
 | 
						|
  return ConstantFP::get(getGlobalContext(), APFloat(Val));
 | 
						|
}
 | 
						|
 | 
						|
ExprAST *Error(const char *Str) { fprintf(stderr, "Error: %s\n", Str);return 0;}
 | 
						|
Value *ErrorV(const char *Str) { Error(Str); return 0; }
 | 
						|
 | 
						|
///// VariableExprAST - Expression class for referencing a variable, like "a".
 | 
						|
//class VariableExprAST : public ExprAST {
 | 
						|
//  std::string Name;
 | 
						|
//public:
 | 
						|
//  VariableExprAST(const std::string &name) : Name(name) {}
 | 
						|
//  virtual Value *Codegen();
 | 
						|
//};
 | 
						|
 | 
						|
 | 
						|
/// BinaryExprAST - Expression class for a binary operator.
 | 
						|
class BinaryExprAST : public ExprAST {
 | 
						|
  char Op;    // ObItemType
 | 
						|
  ExprAST *LHS, *RHS;
 | 
						|
public:
 | 
						|
  BinaryExprAST(char op, ExprAST *lhs, ExprAST *rhs)
 | 
						|
    : Op(op), LHS(lhs), RHS(rhs) {}
 | 
						|
  void add_child_ast(ExprAST *child);
 | 
						|
  virtual Value *Codegen();
 | 
						|
};
 | 
						|
 | 
						|
Value *BinaryExprAST::Codegen() {
 | 
						|
  Value *L = LHS->Codegen();
 | 
						|
  Value *R = RHS->Codegen();
 | 
						|
  if (L == 0 || R == 0) return 0;
 | 
						|
//  printf("binary op %d\n", Op);
 | 
						|
  switch (Op) {
 | 
						|
  case T_OP_ADD: return Builder.CreateFAdd(L, R, "addtmp");
 | 
						|
  case T_OP_MINUS: return Builder.CreateFSub(L, R, "subtmp");
 | 
						|
  case T_OP_MUL: return Builder.CreateFMul(L, R, "multmp");
 | 
						|
  case T_OP_LT:
 | 
						|
    L = Builder.CreateFCmpULT(L, R, "cmptmp");
 | 
						|
    // Convert bool 0/1 to double 0.0 or 1.0
 | 
						|
    return Builder.CreateUIToFP(L, Type::getDoubleTy(getGlobalContext()),
 | 
						|
                                "booltmp");
 | 
						|
  default: return ErrorV("invalid binary operator");
 | 
						|
  }
 | 
						|
}
 | 
						|
 | 
						|
void BinaryExprAST::add_child_ast(ExprAST *child)
 | 
						|
{
 | 
						|
  if (LHS) {
 | 
						|
    RHS = child;
 | 
						|
  } else {
 | 
						|
    LHS = child;
 | 
						|
  }
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
/// PrototypeAST - This class represents the "prototype" for a function,
 | 
						|
/// which captures its name, and its argument names (thus implicitly the number
 | 
						|
/// of arguments the function takes).
 | 
						|
class PrototypeAST {
 | 
						|
  std::string Name;
 | 
						|
  std::vector<std::string> Args;
 | 
						|
  std::map<std::string, Value*> NamedValues;
 | 
						|
public:
 | 
						|
  PrototypeAST(const std::string &name, const std::vector<std::string> &args)
 | 
						|
    : Name(name), Args(args) {}
 | 
						|
 | 
						|
  Function *Codegen();
 | 
						|
};
 | 
						|
 | 
						|
Function *PrototypeAST::Codegen() {
 | 
						|
  // Make the function type:  double(double,double) etc.
 | 
						|
  std::vector<Type*> Doubles(Args.size(),
 | 
						|
                             Type::getDoubleTy(getGlobalContext()));
 | 
						|
  FunctionType *FT = FunctionType::get(Type::getDoubleTy(getGlobalContext()),
 | 
						|
                                       Doubles, false);
 | 
						|
 | 
						|
  Function *F = Function::Create(FT, Function::ExternalLinkage, Name, TheModule);
 | 
						|
 | 
						|
  // If F conflicted, there was already something named 'Name'.  If it has a
 | 
						|
  // body, don't allow redefinition or reextern.
 | 
						|
  if (F->getName() != Name) {
 | 
						|
    // Delete the one we just made and get the existing one.
 | 
						|
    F->eraseFromParent();
 | 
						|
    F = TheModule->getFunction(Name);
 | 
						|
 | 
						|
    // If F already has a body, reject this.
 | 
						|
    if (!F->empty()) {
 | 
						|
      printf("redefinition of function");
 | 
						|
      return 0;
 | 
						|
    }
 | 
						|
 | 
						|
    // If F took a different number of args, reject.
 | 
						|
    if (F->arg_size() != Args.size()) {
 | 
						|
      printf("redefinition of function with different # args");
 | 
						|
      return 0;
 | 
						|
    }
 | 
						|
  }
 | 
						|
 | 
						|
  // Set names for all arguments.
 | 
						|
  unsigned Idx = 0;
 | 
						|
  for (Function::arg_iterator AI = F->arg_begin(); Idx != Args.size();
 | 
						|
       ++AI, ++Idx) {
 | 
						|
    AI->setName(Args[Idx]);
 | 
						|
 | 
						|
    // Add arguments to variable symbol table.
 | 
						|
    NamedValues[Args[Idx]] = AI;
 | 
						|
  }
 | 
						|
 | 
						|
  return F;
 | 
						|
}
 | 
						|
 | 
						|
/// FunctionAST - This class represents a function definition itself.
 | 
						|
class FunctionAST {
 | 
						|
  PrototypeAST *Proto;
 | 
						|
  ExprAST *Body;
 | 
						|
public:
 | 
						|
  FunctionAST(PrototypeAST *proto, ExprAST *body)
 | 
						|
    : Proto(proto), Body(body) {}
 | 
						|
 | 
						|
  Function *Codegen();
 | 
						|
};
 | 
						|
 | 
						|
Function *FunctionAST::Codegen() {
 | 
						|
//  NamedValues.clear();
 | 
						|
 | 
						|
  Function *TheFunction = Proto->Codegen();
 | 
						|
  if (TheFunction == 0)
 | 
						|
    return 0;
 | 
						|
 | 
						|
  // Create a new basic block to start insertion into.
 | 
						|
  BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
 | 
						|
  Builder.SetInsertPoint(BB);
 | 
						|
 | 
						|
  if (Value *RetVal = Body->Codegen()) {
 | 
						|
    // Finish off the function.
 | 
						|
    Builder.CreateRet(RetVal);
 | 
						|
 | 
						|
    // Validate the generated code, checking for consistency.
 | 
						|
    verifyFunction(*TheFunction);
 | 
						|
 | 
						|
    // Optimize the function.
 | 
						|
    TheFPM->run(*TheFunction);
 | 
						|
 | 
						|
    return TheFunction;
 | 
						|
  }
 | 
						|
 | 
						|
  // Error reading body, remove function.
 | 
						|
  TheFunction->eraseFromParent();
 | 
						|
  return 0;
 | 
						|
}
 | 
						|
 | 
						|
int TestLLVM::resolve(const char* expr, const char *&json_expr)
 | 
						|
{
 | 
						|
  int ret = OB_SUCCESS;
 | 
						|
  ObArray<ObRawExpr*> expr_store;
 | 
						|
  ObArray<ObQualifiedName> columns;
 | 
						|
  ObArray<ObVarInfo> sys_vars;
 | 
						|
  ObArray<ObSubQueryInfo> sub_query_info;
 | 
						|
  const char* expr_str = expr;
 | 
						|
  ObArenaAllocator allocator(ObModIds::TEST);
 | 
						|
  ObTimeZoneInfo tz_info;
 | 
						|
  ObNameCaseMode case_mode = OB_NAME_CASE_INVALID;
 | 
						|
  ObRawExprFactory expr_factory(allocator);
 | 
						|
  ObExprResolveContext ctx(expr_factory, &tz_info, case_mode);
 | 
						|
  ctx.connection_charset_ = ObCharset::get_default_charset();
 | 
						|
  ctx.dest_collation_ = ObCharset::get_default_collation(ctx.connection_charset_);
 | 
						|
  ctx.is_extract_param_type_ = false;
 | 
						|
  ObRawExpr *raw_expr = NULL;
 | 
						|
  ObArray<ObAggFunRawExpr*> aggr_exprs;
 | 
						|
  ObArray<ObWinFunRawExpr*> win_exprs;
 | 
						|
  ObArray<ObUDFInfo> udf_info;
 | 
						|
  if (OB_FAIL(ObRawExprUtils::make_raw_expr_from_str(expr_str,
 | 
						|
                                                     strlen(expr_str),
 | 
						|
                                                     ctx,
 | 
						|
                                                     raw_expr,
 | 
						|
                                                     columns,
 | 
						|
                                                     sys_vars,
 | 
						|
                                                     &sub_query_info,
 | 
						|
                                                     aggr_exprs,
 | 
						|
                                                     win_exprs,
 | 
						|
                                                     udf_info)))
 | 
						|
  {
 | 
						|
    printf("error");
 | 
						|
  }
 | 
						|
 | 
						|
  _OB_LOG(DEBUG, "================================================================");
 | 
						|
  _OB_LOG(DEBUG, "%s", expr);
 | 
						|
  _OB_LOG(DEBUG, "%s", CSJ(raw_expr));
 | 
						|
  if (OB_FAIL(raw_expr->extract_info())) {
 | 
						|
    printf("error");
 | 
						|
  }
 | 
						|
  //OK(raw_expr->deduce_type());
 | 
						|
  json_expr = CSJ(raw_expr);
 | 
						|
 | 
						|
 | 
						|
  ExprAST *ast = NULL;
 | 
						|
  if (OB_FAIL(convert(raw_expr, ast))) {
 | 
						|
    printf("fail to convert");
 | 
						|
  } else {
 | 
						|
    InitializeNativeTarget();
 | 
						|
    LLVMContext &Context = getGlobalContext();
 | 
						|
    // Make the module, which holds all the code.
 | 
						|
    TheModule = new Module("my cool jit", Context);
 | 
						|
 | 
						|
    // Create the JIT.  This takes ownership of the module.
 | 
						|
    std::string ErrStr;
 | 
						|
    TheExecutionEngine = EngineBuilder(TheModule).setErrorStr(&ErrStr).create();
 | 
						|
    if (!TheExecutionEngine) {
 | 
						|
      printf("Could not create ExecutionEngine\n");
 | 
						|
      std::cout << "ErrStr : " << ErrStr << std::endl;
 | 
						|
      OB_ASSERT(0);
 | 
						|
    } else {
 | 
						|
 | 
						|
 | 
						|
      FunctionPassManager OurFPM(TheModule);
 | 
						|
 | 
						|
      // Set up the optimizer pipeline.  Start with registering info about how the
 | 
						|
      // target lays out data structures.
 | 
						|
      OurFPM.add(new TargetData(*TheExecutionEngine->getTargetData()));
 | 
						|
      // Provide basic AliasAnalysis support for GVN.
 | 
						|
      OurFPM.add(createBasicAliasAnalysisPass());
 | 
						|
      // Do simple "peephole" optimizations and bit-twiddling optzns.
 | 
						|
      OurFPM.add(createInsstructionCombiningPass());
 | 
						|
      // Reassociate expressions.
 | 
						|
      OurFPM.add(createReassociatePass());
 | 
						|
      // Eliminate Common SubExpressions.
 | 
						|
      OurFPM.add(createGVNPass());
 | 
						|
      // Simplify the control flow graph (deleting unreachable blocks, etc).
 | 
						|
      OurFPM.add(createCFGSimplificationPass());
 | 
						|
 | 
						|
      OurFPM.doInitialization();
 | 
						|
 | 
						|
      // Set the global so the code gen can use this.
 | 
						|
      TheFPM = &OurFPM;
 | 
						|
 | 
						|
 | 
						|
      PrototypeAST *Proto = new PrototypeAST("", std::vector<std::string>());
 | 
						|
      FunctionAST *func = new FunctionAST(Proto, ast);
 | 
						|
 | 
						|
      if (!func) {
 | 
						|
      } else if (Function *LF = func->Codegen()) {
 | 
						|
        // JIT the function, returning a function pointer.
 | 
						|
        int64_t i, j, m, n;
 | 
						|
        i = get_usec();
 | 
						|
        void *FPtr = TheExecutionEngine->getPointerToFunction(LF);
 | 
						|
        j = get_usec();
 | 
						|
        printf("JIT in %ld us\n", j - i);
 | 
						|
        // Cast it to the right type (takes no arguments, returns a double) so we
 | 
						|
        // can call it as a native function.
 | 
						|
        double (*FP)() = (double (*)())(intptr_t)FPtr;
 | 
						|
        m = get_usec();
 | 
						|
        for (int64_t k = 0; k < 10000000; k++) {
 | 
						|
          FP();
 | 
						|
        }
 | 
						|
        n = get_usec();
 | 
						|
        printf("Evaluated to %f\n in %ld us\n", FP(), n-m);
 | 
						|
      }
 | 
						|
    }
 | 
						|
  }
 | 
						|
 | 
						|
  return ret;
 | 
						|
}
 | 
						|
 | 
						|
/// Converts an OceanBase raw expression tree into the ExprAST tree used for
/// LLVM code generation.
///
/// Supported inputs:
///  - constant leaves, read with get_int() and widened to double
///    (NOTE(review): assumes numeric literals are integers — confirm before
///    feeding non-integer constants);
///  - operator nodes with exactly two parameters, which become BinaryExprAST
///    nodes keyed by the raw expression's ObItemType.
///
/// @param raw_expr  expression to convert; must not be NULL
/// @param ast       out: newly allocated tree owned by the caller on success;
///                  NULL on failure (any partially built tree is deleted)
/// @return OB_SUCCESS; OB_INVALID_ARGUMENT for a NULL input; OB_NOT_SUPPORTED
///         for an expression class or arity that cannot be translated; or the
///         error from converting a child.
int TestLLVM::convert(ObRawExpr *raw_expr, ExprAST *&ast)
{
  int ret = OB_SUCCESS;
  ast = NULL;
  if (NULL == raw_expr) {
    ret = OB_INVALID_ARGUMENT;
    printf("invalid raw expr\n");
  } else if (ObRawExpr::EXPR_CONST == raw_expr->get_expr_class()) {
    ObConstRawExpr *const_expr = static_cast<ObConstRawExpr*> (raw_expr);
    ast = new NumberExprAST(static_cast<double> (const_expr->get_value().get_int()));
  } else if (ObRawExpr::EXPR_OPERATOR == raw_expr->get_expr_class()) {
    ObOpRawExpr *op_expr = static_cast<ObOpRawExpr *> (raw_expr);
    if (2 != op_expr->get_param_count()) {
      // BinaryExprAST can only hold two operands.
      ret = OB_NOT_SUPPORTED;
      printf("only binary operators are supported\n");
    } else {
      BinaryExprAST *binary_ast = new BinaryExprAST(raw_expr->get_expr_type(), NULL, NULL);
      ast = binary_ast;
      for (int64_t i = 0; OB_SUCC(ret) && i < op_expr->get_param_count(); i++) {
        ExprAST *child_ast = NULL;
        if (OB_FAIL(convert(op_expr->get_param_exprs().at(i), child_ast))) {
          printf("fail to convert");
        } else {
          binary_ast->add_child_ast(child_ast);
        }
      }
    }
  } else {
    // Any other expression class (columns, functions, ...) has no ExprAST form.
    ret = OB_NOT_SUPPORTED;
    printf("unsupported expression class\n");
  }

  if (OB_SUCCESS != ret && NULL != ast) {
    // Do not hand a half-built tree back to the caller.
    delete ast;
    ast = NULL;
  }
  return ret;
}
 | 
						|
 | 
						|
/// Runs a few arithmetic expressions through resolve(), i.e. parse ->
/// convert -> LLVM JIT. Fails on the first expression that does not return
/// OB_SUCCESS.
TEST_F(TestLLVM, all)
{
  static const char *const exprs[] = {
    "1+2",
    "2-1",
    "1345435+ 1232132",
    "6/3",
  };
  const char* json_expr = NULL;
  for (size_t i = 0; i < sizeof(exprs) / sizeof(exprs[0]); ++i) {
    ASSERT_EQ(OB_SUCCESS, resolve(exprs[i], json_expr))
        << "fail to resolve: " << exprs[i];
  }
  // Disabled file-driven variant, kept for reference:
//  static const char* test_file = "./expr/test_raw_expr_resolver.test";
//  static const char* tmp_file = "./expr/test_raw_expr_resolver.tmp";
//  static const char* result_file = "./expr/test_raw_expr_resolver.result";
//
//  std::ifstream if_tests(test_file);
//  ASSERT_TRUE(if_tests.is_open());
//  std::string line;
//  const char* json_expr = NULL;
//  std::ofstream of_result(tmp_file);
//  ASSERT_TRUE(of_result.is_open());
//  int64_t case_id = 0;
//  while (std::getline(if_tests, line)) {
//    of_result << '[' << case_id++ << "] " << line << std::endl;
//    resolve(line.c_str(), json_expr);
//    of_result << json_expr << std::endl;
//  }
//  of_result.close();
//  // verify results
//  fprintf(stderr, "If tests failed, use `diff %s %s' to see the differences. \n", result_file, tmp_file);
//  std::ifstream if_result(tmp_file);
//  ASSERT_TRUE(if_result.is_open());
//  std::istream_iterator<std::string> it_result(if_result);
//  std::ifstream if_expected(result_file);
//  ASSERT_TRUE(if_expected.is_open());
//  std::istream_iterator<std::string> it_expected(if_expected);
//  ASSERT_TRUE(std::equal(it_result, std::istream_iterator<std::string>(), it_expected));
//  std::remove(tmp_file);
}
 | 
						|
 | 
						|
int main(int argc, char **argv)
 | 
						|
{
 | 
						|
  ::testing::InitGoogleTest(&argc,argv);
 | 
						|
  init_sql_factories();
 | 
						|
  return RUN_ALL_TESTS();
 | 
						|
}
 |