wangzelin.wzl 93a1074b0c patch 4.0
2022-10-24 17:57:12 +08:00

456 lines
14 KiB
C++

/**
* Copyright (c) 2021 OceanBase
* OceanBase CE is licensed under Mulan PubL v2.
* You can use this software according to the terms and conditions of the Mulan PubL v2.
* You may obtain a copy of Mulan PubL v2 at:
* http://license.coscl.org.cn/MulanPubL-2.0
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PubL v2 for more details.
*/
#include <gtest/gtest.h>
#include "sql/test_sql_utils.h"
#include "lib/utility/ob_test_util.h"
#include "sql/resolver/expr/ob_raw_expr_resolver_impl.h"
#include "sql/resolver/expr/ob_raw_expr_util.h"
#include "sql/resolver/expr/ob_raw_expr_print_visitor.h"
#include "sql/ob_sql_init.h"
#include "lib/json/ob_json_print_utils.h"
#include <fstream>
// files needed by LLVM
#ifdef ID
#undef ID
#endif
#include "llvm/DerivedTypes.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/JIT.h"
#include "llvm/LLVMContext.h"
#include "llvm/Module.h"
#include "llvm/PassManager.h"
#include "llvm/Analysis/Verifier.h"
#include "llvm/Analysis/Passes.h"
#include "llvm/Target/TargetData.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Support/IRBuilder.h"
#include "llvm/Support/TargetSelect.h"
#include <cstdio>
#include <string>
#include <map>
#include <vector>
using namespace llvm;
using namespace oceanbase::common;
using namespace oceanbase::sql;

// File-wide LLVM state shared by the code generators below.
// Module that receives every generated function (assigned in TestLLVM::resolve).
static Module *TheModule = NULL;
// JIT engine created for TheModule (assigned in TestLLVM::resolve).
static ExecutionEngine *TheExecutionEngine = NULL;
// Function-level optimisation pipeline run by FunctionAST::Codegen.
static FunctionPassManager *TheFPM = NULL;
// Instruction builder bound to the global LLVM context.
static IRBuilder<> Builder(getGlobalContext());
////static std::map<std::string, Value*> NamedValues;
/// Returns the current wall-clock time in microseconds, as reported by
/// gettimeofday().
///
/// The seconds field is widened to int64_t before it is scaled, so the product
/// is computed in 64 bits even where time_t is narrower.
static int64_t get_usec()
{
  struct timeval time_val;
  // No timezone information is requested.
  gettimeofday(&time_val, NULL);
  return static_cast<int64_t>(time_val.tv_sec) * 1000000
      + static_cast<int64_t>(time_val.tv_usec);
}
class ExprAST;

/// Test fixture that resolves a SQL expression string, converts the resolved
/// expression into an ExprAST tree and JIT-compiles it with LLVM.
class TestLLVM : public ::testing::Test
{
public:
  TestLLVM();
  virtual ~TestLLVM();
  virtual void SetUp();
  virtual void TearDown();

protected:
  // function members
  // Resolves `expr`, JIT-compiles and runs it; `json_expr` receives the CSJ()
  // rendering of the resolved expression.
  int resolve(const char* expr, const char *&json_expr);
  // Recursively builds an ExprAST for `raw_expr` and hands it back via `ast`.
  int convert(ObRawExpr *raw_expr, ExprAST *&ast);

private:
  DISALLOW_COPY_AND_ASSIGN(TestLLVM);
};
// The fixture declares no data members, so construction, destruction and the
// per-test hooks have nothing to do.
TestLLVM::TestLLVM() {}

TestLLVM::~TestLLVM() {}

void TestLLVM::SetUp() {}

void TestLLVM::TearDown() {}
/// ExprAST - Abstract root of the expression tree. Every concrete node
/// implements Codegen() to emit the LLVM value it stands for.
class ExprAST
{
public:
  // Emits IR for this node and returns the resulting value.
  virtual Value *Codegen() = 0;
  virtual ~ExprAST() {}
};
/// NumberExprAST - Leaf node holding a numeric literal such as "1.0".
class NumberExprAST : public ExprAST
{
public:
  NumberExprAST(double val) : Val(val) {}
  virtual Value *Codegen();

private:
  double Val;  // the literal, emitted as a floating-point constant
};
/// Emits the stored literal as a floating-point constant in the global context.
Value *NumberExprAST::Codegen()
{
  LLVMContext &ctx = getGlobalContext();
  const APFloat literal(Val);
  return ConstantFP::get(ctx, literal);
}
ExprAST *Error(const char *Str) { fprintf(stderr, "Error: %s\n", Str);return 0;}
/// ErrorV - Reports Str through Error() and returns a null Value pointer, for
/// use in code generators whose return type is Value*.
Value *ErrorV(const char *Str)
{
  Error(Str);
  return 0;
}
///// VariableExprAST - Expression class for referencing a variable, like "a".
//class VariableExprAST : public ExprAST {
// std::string Name;
//public:
// VariableExprAST(const std::string &name) : Name(name) {}
// virtual Value *Codegen();
//};
/// BinaryExprAST - Expression class for a binary operator.
///
/// The operator is stored as an int rather than a char: it is compared against
/// T_OP_* enumerators in Codegen(), and a char only round-trips such a value
/// while the enumerator happens to fit in char's range.
/// Operands may be supplied at construction or attached afterwards with
/// add_child_ast(). The node declares no destructor, so it does not delete
/// its operands.
class BinaryExprAST : public ExprAST {
  int Op;  // ObItemType of the operator
  ExprAST *LHS, *RHS;
public:
  BinaryExprAST(int op, ExprAST *lhs, ExprAST *rhs)
    : Op(op), LHS(lhs), RHS(rhs) {}
  // Stores `child` as LHS while LHS is still null, otherwise as RHS.
  void add_child_ast(ExprAST *child);
  virtual Value *Codegen();
};
/// Emits IR for the operator applied to both operands.
///
/// Returns the resulting value, or 0 when an operand is missing, when either
/// operand fails to generate, or when the operator is not one of
/// T_OP_ADD / T_OP_MINUS / T_OP_MUL / T_OP_LT.
Value *BinaryExprAST::Codegen()
{
  // Operands can be attached after construction, so a slot may still be null
  // (for example when the source operator had fewer than two parameters).
  if (LHS == 0 || RHS == 0) {
    return ErrorV("binary operator is missing an operand");
  }
  Value *L = LHS->Codegen();
  Value *R = RHS->Codegen();
  if (L == 0 || R == 0) {
    return 0;
  }
  // printf("binary op %d\n", Op);
  switch (Op) {
    case T_OP_ADD:
      return Builder.CreateFAdd(L, R, "addtmp");
    case T_OP_MINUS:
      return Builder.CreateFSub(L, R, "subtmp");
    case T_OP_MUL:
      return Builder.CreateFMul(L, R, "multmp");
    case T_OP_LT:
      L = Builder.CreateFCmpULT(L, R, "cmptmp");
      // Convert bool 0/1 to double 0.0 or 1.0
      return Builder.CreateUIToFP(L, Type::getDoubleTy(getGlobalContext()),
                                  "booltmp");
    default:
      return ErrorV("invalid binary operator");
  }
}
/// Attaches `child` as an operand: it becomes LHS while LHS is still null,
/// otherwise it is written to RHS (replacing whatever RHS held).
void BinaryExprAST::add_child_ast(ExprAST *child)
{
  ExprAST *&slot = (LHS ? RHS : LHS);
  slot = child;
}
/// PrototypeAST - The signature of a function: its name plus the names of its
/// arguments (and therefore how many arguments it takes).
class PrototypeAST
{
public:
  PrototypeAST(const std::string &name, const std::vector<std::string> &args)
    : Name(name), Args(args) {}
  // Creates (or reuses) the function declaration in TheModule.
  Function *Codegen();

private:
  std::string Name;
  std::vector<std::string> Args;
  // Argument name -> argument value, filled in by Codegen().
  std::map<std::string, Value*> NamedValues;
};
/// Declares the function `double Name(double, ...)` in TheModule, with one
/// double parameter per entry of Args, and records each argument in
/// NamedValues.
///
/// Returns the declaration, or 0 when the name is already taken by something
/// that cannot be reused: a non-function symbol, a function that already has a
/// body, or a function with a different number of arguments.
Function *PrototypeAST::Codegen()
{
  // Make the function type: double(double,double) etc.
  Type *DoubleTy = Type::getDoubleTy(getGlobalContext());
  std::vector<Type*> Doubles(Args.size(), DoubleTy);
  FunctionType *FT = FunctionType::get(DoubleTy, Doubles, false);
  Function *F = Function::Create(FT, Function::ExternalLinkage, Name, TheModule);
  // If F conflicted, there was already something named 'Name'. If it has a
  // body, don't allow redefinition or reextern.
  if (F->getName() != Name) {
    // Delete the one we just made and get the existing one.
    F->eraseFromParent();
    F = TheModule->getFunction(Name);
    // The lookup yields null when no function of that name exists, e.g. when
    // the clashing symbol is not a function; there is nothing to reuse then.
    if (F == 0) {
      printf("name is already used by a non-function symbol\n");
      return 0;
    }
    // If F already has a body, reject this.
    if (!F->empty()) {
      printf("redefinition of function");
      return 0;
    }
    // If F took a different number of args, reject.
    if (F->arg_size() != Args.size()) {
      printf("redefinition of function with different # args");
      return 0;
    }
  }
  // Set names for all arguments.
  unsigned Idx = 0;
  for (Function::arg_iterator AI = F->arg_begin(); Idx != Args.size();
       ++AI, ++Idx) {
    AI->setName(Args[Idx]);
    // Add arguments to variable symbol table.
    NamedValues[Args[Idx]] = AI;
  }
  return F;
}
/// FunctionAST - A complete function definition: a prototype together with
/// the expression that forms its body.
class FunctionAST
{
public:
  FunctionAST(PrototypeAST *proto, ExprAST *body)
    : Proto(proto), Body(body) {}
  // Generates the function; returns 0 if the prototype or the body fails.
  Function *Codegen();

private:
  PrototypeAST *Proto;
  ExprAST *Body;
};
/// Generates the function: declares it from the prototype, emits the body into
/// a fresh "entry" block, returns the body's value, then verifies and
/// optimises the result. Returns 0 (after removing the half-built function
/// where one exists) if the prototype or the body cannot be generated.
Function *FunctionAST::Codegen()
{
  // NamedValues.clear();
  Function *Fn = Proto->Codegen();
  if (Fn == 0) {
    return 0;
  }

  // All instructions of the body go into a single entry block.
  BasicBlock *Entry = BasicBlock::Create(getGlobalContext(), "entry", Fn);
  Builder.SetInsertPoint(Entry);

  Value *Result = Body->Codegen();
  if (Result == 0) {
    // The body could not be generated; drop the incomplete function.
    Fn->eraseFromParent();
    return 0;
  }

  // Finish off the function, check it for consistency and optimise it.
  Builder.CreateRet(Result);
  verifyFunction(*Fn);
  TheFPM->run(*Fn);
  return Fn;
}
int TestLLVM::resolve(const char* expr, const char *&json_expr)
{
int ret = OB_SUCCESS;
ObArray<ObRawExpr*> expr_store;
ObArray<ObQualifiedName> columns;
ObArray<ObVarInfo> sys_vars;
ObArray<ObSubQueryInfo> sub_query_info;
const char* expr_str = expr;
ObArenaAllocator allocator(ObModIds::TEST);
ObTimeZoneInfo tz_info;
ObNameCaseMode case_mode = OB_NAME_CASE_INVALID;
ObRawExprFactory expr_factory(allocator);
ObExprResolveContext ctx(expr_factory, &tz_info, case_mode);
ctx.connection_charset_ = ObCharset::get_default_charset();
ctx.dest_collation_ = ObCharset::get_default_collation(ctx.connection_charset_);
ctx.is_extract_param_type_ = false;
ObRawExpr *raw_expr = NULL;
ObArray<ObAggFunRawExpr*> aggr_exprs;
ObArray<ObWinFunRawExpr*> win_exprs;
ObArray<ObUDFInfo> udf_info;
if (OB_FAIL(ObRawExprUtils::make_raw_expr_from_str(expr_str,
strlen(expr_str),
ctx,
raw_expr,
columns,
sys_vars,
&sub_query_info,
aggr_exprs,
win_exprs,
udf_info)))
{
printf("error");
}
_OB_LOG(DEBUG, "================================================================");
_OB_LOG(DEBUG, "%s", expr);
_OB_LOG(DEBUG, "%s", CSJ(raw_expr));
if (OB_FAIL(raw_expr->extract_info())) {
printf("error");
}
//OK(raw_expr->deduce_type());
json_expr = CSJ(raw_expr);
ExprAST *ast = NULL;
if (OB_FAIL(convert(raw_expr, ast))) {
printf("fail to convert");
} else {
InitializeNativeTarget();
LLVMContext &Context = getGlobalContext();
// Make the module, which holds all the code.
TheModule = new Module("my cool jit", Context);
// Create the JIT. This takes ownership of the module.
std::string ErrStr;
TheExecutionEngine = EngineBuilder(TheModule).setErrorStr(&ErrStr).create();
if (!TheExecutionEngine) {
printf("Could not create ExecutionEngine\n");
std::cout << "ErrStr : " << ErrStr << std::endl;
OB_ASSERT(0);
} else {
FunctionPassManager OurFPM(TheModule);
// Set up the optimizer pipeline. Start with registering info about how the
// target lays out data structures.
OurFPM.add(new TargetData(*TheExecutionEngine->getTargetData()));
// Provide basic AliasAnalysis support for GVN.
OurFPM.add(createBasicAliasAnalysisPass());
// Do simple "peephole" optimizations and bit-twiddling optzns.
OurFPM.add(createInsstructionCombiningPass());
// Reassociate expressions.
OurFPM.add(createReassociatePass());
// Eliminate Common SubExpressions.
OurFPM.add(createGVNPass());
// Simplify the control flow graph (deleting unreachable blocks, etc).
OurFPM.add(createCFGSimplificationPass());
OurFPM.doInitialization();
// Set the global so the code gen can use this.
TheFPM = &OurFPM;
PrototypeAST *Proto = new PrototypeAST("", std::vector<std::string>());
FunctionAST *func = new FunctionAST(Proto, ast);
if (!func) {
} else if (Function *LF = func->Codegen()) {
// JIT the function, returning a function pointer.
int64_t i, j, m, n;
i = get_usec();
void *FPtr = TheExecutionEngine->getPointerToFunction(LF);
j = get_usec();
printf("JIT in %ld us\n", j - i);
// Cast it to the right type (takes no arguments, returns a double) so we
// can call it as a native function.
double (*FP)() = (double (*)())(intptr_t)FPtr;
m = get_usec();
for (int64_t k = 0; k < 10000000; k++) {
FP();
}
n = get_usec();
printf("Evaluated to %f\n in %ld us\n", FP(), n-m);
}
}
}
return ret;
}
// Recursively converts a resolved expression into an ExprAST tree.
//
// A constant becomes a NumberExprAST holding the constant's integer value as a
// double; an operator becomes a BinaryExprAST whose operands are the converted
// parameter expressions, attached in parameter order.
//
// @param raw_expr  expression to convert; must not be NULL.
// @param ast       receives the newly allocated node (NULL if none was built).
// @return OB_SUCCESS, OB_INVALID_ARGUMENT for a NULL input, OB_NOT_SUPPORTED
//         for any other expression class, or the error of a failed operand.
int TestLLVM::convert(ObRawExpr *raw_expr, ExprAST *&ast)
{
  int ret = OB_SUCCESS;
  ast = NULL;
  if (NULL == raw_expr) {
    ret = OB_INVALID_ARGUMENT;
    printf("fail to convert");
  } else {
    switch (raw_expr->get_expr_class()) {
      case ObRawExpr::EXPR_CONST: {
        ObConstRawExpr *const_expr = static_cast<ObConstRawExpr*>(raw_expr);
        ast = new NumberExprAST(static_cast<double>(const_expr->get_value().get_int()));
        break;
      }
      case ObRawExpr::EXPR_OPERATOR: {
        ObOpRawExpr *op_expr = static_cast<ObOpRawExpr *>(raw_expr);
        BinaryExprAST *binary_ast = new BinaryExprAST(raw_expr->get_expr_type(), NULL, NULL);
        ast = binary_ast;
        // Convert each parameter and attach it; stop at the first failure.
        for (int64_t i = 0; OB_SUCC(ret) && i < op_expr->get_param_count(); i++) {
          ExprAST *new_ast = NULL;
          if (OB_FAIL(convert(op_expr->get_param_exprs().at(i), new_ast))) {
            printf("fail to convert");
          } else {
            binary_ast->add_child_ast(new_ast);
          }
        }
        break;
      }
      default:
        OB_ASSERT(0);
        // Report the failure as well, so the caller never proceeds with a null
        // node should the assertion not stop execution.
        ret = OB_NOT_SUPPORTED;
        break;
    }
  }
  return ret;
}
// Runs a handful of arithmetic expressions through resolve(), stopping at the
// first one that fails. Every failure is reported the same way.
// NOTE(review): the test only prints on failure and makes no gtest assertion;
// confirm whether a failing resolve() should fail the test.
TEST_F(TestLLVM, all)
{
  const char* json_expr = NULL;
  int ret = OB_SUCCESS;
  static const char *const exprs[] = {"1+2", "2-1", "1345435+ 1232132", "6/3"};
  const int64_t expr_count = static_cast<int64_t>(sizeof(exprs) / sizeof(exprs[0]));
  for (int64_t i = 0; OB_SUCCESS == ret && i < expr_count; ++i) {
    if (OB_SUCCESS != (ret = resolve(exprs[i], json_expr))) {
      printf("fail to resolve\n");
    }
  }
  // Disabled file-driven variant, kept for reference.
  //  static const char* test_file = "./expr/test_raw_expr_resolver.test";
  //  static const char* tmp_file = "./expr/test_raw_expr_resolver.tmp";
  //  static const char* result_file = "./expr/test_raw_expr_resolver.result";
  //
  //  std::ifstream if_tests(test_file);
  //  ASSERT_TRUE(if_tests.is_open());
  //  std::string line;
  //  const char* json_expr = NULL;
  //  std::ofstream of_result(tmp_file);
  //  ASSERT_TRUE(of_result.is_open());
  //  int64_t case_id = 0;
  //  while (std::getline(if_tests, line)) {
  //    of_result << '[' << case_id++ << "] " << line << std::endl;
  //    resolve(line.c_str(), json_expr);
  //    of_result << json_expr << std::endl;
  //  }
  //  of_result.close();
  //  // verify results
  //  fprintf(stderr, "If tests failed, use `diff %s %s' to see the differences. \n", result_file, tmp_file);
  //  std::ifstream if_result(tmp_file);
  //  ASSERT_TRUE(if_result.is_open());
  //  std::istream_iterator<std::string> it_result(if_result);
  //  std::ifstream if_expected(result_file);
  //  ASSERT_TRUE(if_expected.is_open());
  //  std::istream_iterator<std::string> it_expected(if_expected);
  //  ASSERT_TRUE(std::equal(it_result, std::istream_iterator<std::string>(), it_expected));
  //  std::remove(tmp_file);
}
// Test entry point: lets gtest consume its command-line flags, initialises the
// SQL factories, then runs every registered test and returns gtest's result.
int main(int argc, char **argv)
{
  ::testing::InitGoogleTest(&argc, argv);
  init_sql_factories();
  const int exit_code = RUN_ALL_TESTS();
  return exit_code;
}