baidu palo

This commit is contained in:
cyongli
2017-08-11 17:51:21 +08:00
commit e2311f656e
1988 changed files with 586941 additions and 0 deletions

View File

@ -0,0 +1,600 @@
// Modifications copyright (C) 2017, Baidu.com, Inc.
// Copyright 2017 The Apache Software Foundation
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "exprs/arithmetic_expr.h"
#include "codegen/llvm_codegen.h"
#include "codegen/codegen_anyval.h"
#include "runtime/runtime_state.h"
using llvm::BasicBlock;
using llvm::Constant;
using llvm::ConstantFP;
using llvm::ConstantInt;
using llvm::Function;
using llvm::LLVMContext;
using llvm::PHINode;
using llvm::Value;
namespace palo {
Expr* ArithmeticExpr::from_thrift(const TExprNode& node) {
switch (node.opcode) {
case TExprOpcode::ADD:
return new AddExpr(node);
case TExprOpcode::SUBTRACT:
return new SubExpr(node);
case TExprOpcode::MULTIPLY:
return new MulExpr(node);
case TExprOpcode::DIVIDE:
case TExprOpcode::INT_DIVIDE:
return new DivExpr(node);
case TExprOpcode::MOD:
return new ModExpr(node);
case TExprOpcode::BITAND:
return new BitAndExpr(node);
case TExprOpcode::BITOR:
return new BitOrExpr(node);
case TExprOpcode::BITXOR:
return new BitXorExpr(node);
case TExprOpcode::BITNOT:
return new BitNotExpr(node);
default:
return NULL;
}
return NULL;
}
#define BINARY_OP_CHECK_ZERO_FN(TYPE, CLASS, FN, OP) \
TYPE CLASS::FN(ExprContext* context, TupleRow* row) { \
TYPE v1 = _children[0]->FN(context, row); \
if (v1.is_null) { \
return TYPE::null(); \
} \
TYPE v2 = _children[1]->FN(context, row); \
if (v2.is_null || v2.val == 0) { \
return TYPE::null(); \
} \
return TYPE(v1.val OP v2.val); \
}
#define BINARY_OP_FN(TYPE, CLASS, FN, OP) \
TYPE CLASS::FN(ExprContext* context, TupleRow* row) { \
TYPE v1 = _children[0]->FN(context, row); \
if (v1.is_null) { \
return TYPE::null(); \
} \
TYPE v2 = _children[1]->FN(context, row); \
if (v2.is_null) { \
return TYPE::null(); \
} \
return TYPE(v1.val OP v2.val); \
}
#define BINARY_ARITH_FNS(CLASS, OP) \
BINARY_OP_FN(TinyIntVal, CLASS, get_tiny_int_val, OP) \
BINARY_OP_FN(SmallIntVal, CLASS, get_small_int_val, OP) \
BINARY_OP_FN(IntVal, CLASS, get_int_val, OP) \
BINARY_OP_FN(BigIntVal, CLASS, get_big_int_val, OP) \
BINARY_OP_FN(LargeIntVal, CLASS, get_large_int_val, OP) \
BINARY_OP_FN(FloatVal, CLASS, get_float_val, OP) \
BINARY_OP_FN(DoubleVal, CLASS, get_double_val, OP) \
BINARY_ARITH_FNS(AddExpr, +)
BINARY_ARITH_FNS(SubExpr, -)
BINARY_ARITH_FNS(MulExpr, *)
#define BINARY_DIV_FNS() \
BINARY_OP_CHECK_ZERO_FN(TinyIntVal, DivExpr, get_tiny_int_val, /) \
BINARY_OP_CHECK_ZERO_FN(SmallIntVal, DivExpr, get_small_int_val, /) \
BINARY_OP_CHECK_ZERO_FN(IntVal, DivExpr, get_int_val, /) \
BINARY_OP_CHECK_ZERO_FN(BigIntVal, DivExpr, get_big_int_val, /) \
BINARY_OP_CHECK_ZERO_FN(LargeIntVal, DivExpr, get_large_int_val, /) \
BINARY_OP_CHECK_ZERO_FN(FloatVal, DivExpr, get_float_val, /) \
BINARY_OP_CHECK_ZERO_FN(DoubleVal, DivExpr, get_double_val, /) \
BINARY_DIV_FNS()
#define BINARY_MOD_FNS() \
BINARY_OP_CHECK_ZERO_FN(TinyIntVal, ModExpr, get_tiny_int_val, %) \
BINARY_OP_CHECK_ZERO_FN(SmallIntVal, ModExpr, get_small_int_val, %) \
BINARY_OP_CHECK_ZERO_FN(IntVal, ModExpr, get_int_val, %) \
BINARY_OP_CHECK_ZERO_FN(BigIntVal, ModExpr, get_big_int_val, %) \
BINARY_OP_CHECK_ZERO_FN(LargeIntVal, ModExpr, get_large_int_val, %) \
BINARY_MOD_FNS()
FloatVal ModExpr::get_float_val(ExprContext* context, TupleRow* row) {
FloatVal v1 = _children[0]->get_float_val(context, row);
if (v1.is_null) {
return FloatVal::null();
}
FloatVal v2 = _children[1]->get_float_val(context, row);
if (v2.is_null) {
return FloatVal::null();
}
return FloatVal(fmod(v1.val, v2.val));
}
DoubleVal ModExpr::get_double_val(ExprContext* context, TupleRow* row) {
DoubleVal v1 = _children[0]->get_double_val(context, row);
if (v1.is_null) {
return DoubleVal::null();
}
DoubleVal v2 = _children[1]->get_double_val(context, row);
if (v2.is_null) {
return DoubleVal::null();
}
return DoubleVal(fmod(v1.val, v2.val));
}
#define BINARY_BIT_FNS(CLASS, OP) \
BINARY_OP_FN(TinyIntVal, CLASS, get_tiny_int_val, OP) \
BINARY_OP_FN(SmallIntVal, CLASS, get_small_int_val, OP) \
BINARY_OP_FN(IntVal, CLASS, get_int_val, OP) \
BINARY_OP_FN(BigIntVal, CLASS, get_big_int_val, OP) \
BINARY_OP_FN(LargeIntVal, CLASS, get_large_int_val, OP) \
BINARY_BIT_FNS(BitAndExpr, &)
BINARY_BIT_FNS(BitOrExpr, |)
BINARY_BIT_FNS(BitXorExpr, ^)
#define BITNOT_OP_FN(TYPE, FN) \
TYPE BitNotExpr::FN(ExprContext* context, TupleRow* row) { \
TYPE v = _children[0]->FN(context, row); \
if (v.is_null) { \
return TYPE::null(); \
} \
return TYPE(~v.val); \
}
#define BITNOT_FNS() \
BITNOT_OP_FN(TinyIntVal, get_tiny_int_val) \
BITNOT_OP_FN(SmallIntVal, get_small_int_val) \
BITNOT_OP_FN(IntVal, get_int_val) \
BITNOT_OP_FN(BigIntVal, get_big_int_val) \
BITNOT_OP_FN(LargeIntVal, get_large_int_val) \
BITNOT_FNS()
// IR codegen for compound add predicates. Compound predicate has non trivial
// null handling as well as many branches so this is pretty complicated. The IR
// for x && y is:
//
// define i16 @Add(%"class.palo::ExprContext"* %context,
// %"class.palo::TupleRow"* %row) #20 {
// entry:
// %lhs_val = call { i8, i64 } @get_slot_ref(%"class.palo::ExprContext"* %context,
// %"class.palo::TupleRow"* %row)
// %0 = extractvalue { i8, i64 } %lhs_val, 0
// %lhs_is_null = trunc i8 %0 to i1
// br i1 %lhs_is_null, label %null, label %lhs_not_null
//
// lhs_not_null: ; preds = %entry
// %rhs_val = call { i8, i64 } @get_slot_ref(%"class.palo::ExprContext"* %context,
// %"class.palo::TupleRow"* %row)
// %1 = extractvalue { i8, i64 } %lhs_val, 0
// %rhs_is_null = trunc i8 %1 to i1
// br i1 %rhs_is_null, label %null, label %rhs_not_null
//
// rhs_not_null: ; preds = %entry
// %2 = extractvalue { i8, i64 } %lhs_val, 1
// %3 = extractvalue { i8, i64 } %rhs_val, 1
// %val = add i64 %2, %3
// br label %ret
//
// ret: ; preds = %not_null_block, %null_block
// %ret3 = phi i1 [ false, %null_block ], [ %4, %not_null_block ]
// %5 = zext i1 %ret3 to i16
// %6 = shl i16 %5, 8
// %7 = or i16 0, %6
// ret i16 %7
// }
Status ArithmeticExpr::codegen_binary_op(
RuntimeState* state, llvm::Function** fn, BinaryOpType op_type) {
LlvmCodeGen* codegen = NULL;
RETURN_IF_ERROR(state->get_codegen(&codegen));
Function* lhs_fn = NULL;
RETURN_IF_ERROR(_children[0]->get_codegend_compute_fn(state, &lhs_fn));
Function* rhs_fn = NULL;
RETURN_IF_ERROR(_children[1]->get_codegend_compute_fn(state, &rhs_fn));
// Function protocol
Value* args[2];
*fn = create_ir_function_prototype(codegen, "add", &args);
LLVMContext& cg_ctx = codegen->context();
LlvmCodeGen::LlvmBuilder builder(cg_ctx);
// Constant
Value* zero = ConstantInt::get(codegen->get_type(TYPE_TINYINT), 0);
Value* one = ConstantInt::get(codegen->get_type(TYPE_TINYINT), 1);
// Block
BasicBlock* entry_block = BasicBlock::Create(cg_ctx, "entry", *fn);
BasicBlock* lhs_not_null_block = BasicBlock::Create(cg_ctx, "lhs_not_null", *fn);
BasicBlock* rhs_not_null_block = NULL;
if ((op_type == BinaryOpType::DIV || op_type == BinaryOpType::MOD)) {
// && _type.type != TYPE_DOUBLE && _type.type != TYPE_FLOAT) {
rhs_not_null_block = BasicBlock::Create(cg_ctx, "rhs_not_null", *fn);
}
BasicBlock* compute_block = BasicBlock::Create(cg_ctx, "compute", *fn);
BasicBlock* null_block = BasicBlock::Create(cg_ctx, "null", *fn);
BasicBlock* ret_block = BasicBlock::Create(cg_ctx, "ret", *fn);
// entry block
builder.SetInsertPoint(entry_block);
CodegenAnyVal lhs_val = CodegenAnyVal::create_call_wrapped(
codegen, &builder, _children[0]->type(), lhs_fn, args, "lhs_val");
// if (v1.is_null) return null;
Value* lhs_is_null = lhs_val.get_is_null();
builder.CreateCondBr(lhs_is_null, null_block, lhs_not_null_block);
// lhs_not_null_block
builder.SetInsertPoint(lhs_not_null_block);
CodegenAnyVal rhs_val = CodegenAnyVal::create_call_wrapped(
codegen, &builder, _children[1]->type(), rhs_fn, args, "lhs_val");
Value* rhs_is_null = rhs_val.get_is_null();
// if (v2.is_null) return null;
Value* rhs_llvm_val = NULL;
if ((op_type == BinaryOpType::DIV || op_type == BinaryOpType::MOD)) {
// && _type.type != TYPE_DOUBLE && _type.type != TYPE_FLOAT) {
builder.CreateCondBr(rhs_is_null, null_block, rhs_not_null_block);
// if (v2.val == 0)
builder.SetInsertPoint(rhs_not_null_block);
rhs_llvm_val = rhs_val.get_val();
Value* rhs_zero = NULL;
Value* rhs_is_zero = NULL;
if (_type.type == TYPE_DOUBLE || _type.type == TYPE_FLOAT) {
rhs_zero = ConstantFP::get(rhs_llvm_val->getType(), 0.0);
rhs_is_zero = builder.CreateFCmpOEQ(rhs_llvm_val, rhs_zero, "rhs_is_zero");
} else {
rhs_zero = ConstantInt::get(rhs_llvm_val->getType(), 0);
rhs_is_zero = builder.CreateICmpEQ(rhs_llvm_val, rhs_zero, "rhs_is_zero");
}
builder.CreateCondBr(rhs_is_zero, null_block, compute_block);
} else {
builder.CreateCondBr(rhs_is_null, null_block, compute_block);
}
// compute_block
builder.SetInsertPoint(compute_block);
Value* val = NULL;
switch (op_type) {
case ADD: {
switch (_type.type) {
case TYPE_TINYINT:
case TYPE_SMALLINT:
case TYPE_INT:
case TYPE_BIGINT:
case TYPE_LARGEINT:
val = builder.CreateAdd(lhs_val.get_val(), rhs_val.get_val(), "val");
break;
case TYPE_FLOAT:
case TYPE_DOUBLE:
val = builder.CreateFAdd(lhs_val.get_val(), rhs_val.get_val(), "val");
break;
default:
return Status("Unk");
}
break;
}
case SUB: {
switch (_type.type) {
case TYPE_TINYINT:
case TYPE_SMALLINT:
case TYPE_INT:
case TYPE_BIGINT:
case TYPE_LARGEINT:
val = builder.CreateSub(lhs_val.get_val(), rhs_val.get_val(), "val");
break;
case TYPE_FLOAT:
case TYPE_DOUBLE:
val = builder.CreateFSub(lhs_val.get_val(), rhs_val.get_val(), "val");
break;
default:
return Status("Unk");
}
break;
}
case MUL: {
switch (_type.type) {
case TYPE_TINYINT:
case TYPE_SMALLINT:
case TYPE_INT:
case TYPE_BIGINT:
case TYPE_LARGEINT:
val = builder.CreateMul(lhs_val.get_val(), rhs_val.get_val(), "val");
break;
case TYPE_FLOAT:
case TYPE_DOUBLE:
val = builder.CreateFMul(lhs_val.get_val(), rhs_val.get_val(), "val");
break;
default:
return Status("Unk");
}
break;
}
case DIV: {
switch (_type.type) {
case TYPE_TINYINT:
case TYPE_SMALLINT:
case TYPE_INT:
case TYPE_BIGINT:
case TYPE_LARGEINT:
val = builder.CreateSDiv(lhs_val.get_val(), rhs_llvm_val, "val");
break;
case TYPE_FLOAT:
case TYPE_DOUBLE:
val = builder.CreateFDiv(lhs_val.get_val(), rhs_val.get_val(), "val");
break;
default:
return Status("Unk");
}
break;
}
case MOD: {
switch (_type.type) {
case TYPE_TINYINT:
case TYPE_SMALLINT:
case TYPE_INT:
case TYPE_BIGINT:
case TYPE_LARGEINT:
val = builder.CreateSRem(lhs_val.get_val(), rhs_llvm_val, "val");
break;
case TYPE_FLOAT:
case TYPE_DOUBLE:
val = builder.CreateFRem(lhs_val.get_val(), rhs_val.get_val(), "val");
break;
default:
return Status("Unk");
}
break;
}
case BIT_AND: {
switch (_type.type) {
case TYPE_TINYINT:
case TYPE_SMALLINT:
case TYPE_INT:
case TYPE_BIGINT:
case TYPE_LARGEINT:
val = builder.CreateAnd(lhs_val.get_val(), rhs_val.get_val(), "val");
break;
default:
return Status("Unk");
}
break;
}
case BIT_OR: {
switch (_type.type) {
case TYPE_TINYINT:
case TYPE_SMALLINT:
case TYPE_INT:
case TYPE_BIGINT:
case TYPE_LARGEINT:
val = builder.CreateOr(lhs_val.get_val(), rhs_val.get_val(), "val");
break;
default:
return Status("Unk");
}
break;
}
case BIT_XOR: {
switch (_type.type) {
case TYPE_TINYINT:
case TYPE_SMALLINT:
case TYPE_INT:
case TYPE_BIGINT:
case TYPE_LARGEINT:
val = builder.CreateXor(lhs_val.get_val(), rhs_val.get_val(), "val");
break;
default:
return Status("Unk");
}
break;
}
default:
return Status("Unknow operato");
}
builder.CreateBr(ret_block);
// null block
builder.SetInsertPoint(null_block);
builder.CreateBr(ret_block);
// ret block
builder.SetInsertPoint(ret_block);
PHINode* is_null_phi = builder.CreatePHI(codegen->tinyint_type(), 2, "is_null_phi");
is_null_phi->addIncoming(one, null_block);
is_null_phi->addIncoming(zero, compute_block);
PHINode* val_phi = builder.CreatePHI(val->getType(), 2, "val_phi");
Value* null = Constant::getNullValue(val->getType());
val_phi->addIncoming(null, null_block);
val_phi->addIncoming(val, compute_block);
CodegenAnyVal result = CodegenAnyVal::get_non_null_val(
codegen, &builder, type(), "result");
result.set_is_null(is_null_phi);
result.set_val(val_phi);
builder.CreateRet(result.value());
*fn = codegen->finalize_function(*fn);
return Status::OK;
}
Status AddExpr::get_codegend_compute_fn(RuntimeState* state, llvm::Function** fn) {
if (_ir_compute_fn != NULL) {
*fn = _ir_compute_fn;
return Status::OK;
}
RETURN_IF_ERROR(codegen_binary_op(state, fn , BinaryOpType::ADD));
_ir_compute_fn = *fn;
return Status::OK;
}
Status SubExpr::get_codegend_compute_fn(RuntimeState* state, llvm::Function** fn) {
if (_ir_compute_fn != NULL) {
*fn = _ir_compute_fn;
return Status::OK;
}
RETURN_IF_ERROR(codegen_binary_op(state, fn , BinaryOpType::SUB));
_ir_compute_fn = *fn;
return Status::OK;
}
Status MulExpr::get_codegend_compute_fn(RuntimeState* state, llvm::Function** fn) {
if (_ir_compute_fn != NULL) {
*fn = _ir_compute_fn;
return Status::OK;
}
RETURN_IF_ERROR(codegen_binary_op(state, fn , BinaryOpType::MUL));
_ir_compute_fn = *fn;
return Status::OK;
}
Status DivExpr::get_codegend_compute_fn(RuntimeState* state, llvm::Function** fn) {
if (_ir_compute_fn != NULL) {
*fn = _ir_compute_fn;
return Status::OK;
}
RETURN_IF_ERROR(codegen_binary_op(state, fn , BinaryOpType::DIV));
_ir_compute_fn = *fn;
return Status::OK;
}
Status ModExpr::get_codegend_compute_fn(RuntimeState* state, llvm::Function** fn) {
if (_ir_compute_fn != NULL) {
*fn = _ir_compute_fn;
return Status::OK;
}
RETURN_IF_ERROR(codegen_binary_op(state, fn , BinaryOpType::MOD));
_ir_compute_fn = *fn;
return Status::OK;
}
Status BitAndExpr::get_codegend_compute_fn(RuntimeState* state, llvm::Function** fn) {
if (_ir_compute_fn != NULL) {
*fn = _ir_compute_fn;
return Status::OK;
}
RETURN_IF_ERROR(codegen_binary_op(state, fn , BinaryOpType::BIT_AND));
_ir_compute_fn = *fn;
return Status::OK;
}
Status BitOrExpr::get_codegend_compute_fn(RuntimeState* state, llvm::Function** fn) {
if (_ir_compute_fn != NULL) {
*fn = _ir_compute_fn;
return Status::OK;
}
RETURN_IF_ERROR(codegen_binary_op(state, fn , BinaryOpType::BIT_OR));
_ir_compute_fn = *fn;
return Status::OK;
}
Status BitXorExpr::get_codegend_compute_fn(RuntimeState* state, llvm::Function** fn) {
if (_ir_compute_fn != NULL) {
*fn = _ir_compute_fn;
return Status::OK;
}
RETURN_IF_ERROR(codegen_binary_op(state, fn , BinaryOpType::BIT_XOR));
_ir_compute_fn = *fn;
return Status::OK;
}
// IR codegen for compound add predicates. Compound predicate has non trivial
// null handling as well as many branches so this is pretty complicated. The IR
// for x && y is:
//
// define i16 @Add(%"class.palo::ExprContext"* %context,
// %"class.palo::TupleRow"* %row) #20 {
// entry:
// %lhs_val = call { i8, i64 } @get_slot_ref(%"class.palo::ExprContext"* %context,
// %"class.palo::TupleRow"* %row)
// %0 = extractvalue { i8, i64 } %lhs_val, 0
// %lhs_is_null = trunc i8 %0 to i1
// br i1 %lhs_is_null, label %null, label %lhs_not_null
//
// ret: ; preds = %not_null_block, %null_block
// %ret3 = phi i1 [ false, %null_block ], [ %4, %not_null_block ]
// %5 = zext i1 %ret3 to i16
// %6 = shl i16 %5, 8
// %7 = or i16 0, %6
// ret i16 %7
// }
Status BitNotExpr::get_codegend_compute_fn(RuntimeState* state, llvm::Function** fn) {
if (_ir_compute_fn != NULL) {
*fn = _ir_compute_fn;
return Status::OK;
}
LlvmCodeGen* codegen = NULL;
RETURN_IF_ERROR(state->get_codegen(&codegen));
Function* child_fn = NULL;
RETURN_IF_ERROR(_children[0]->get_codegend_compute_fn(state, &child_fn));
// Function protocol
Value* args[2];
*fn = create_ir_function_prototype(codegen, "add", &args);
LLVMContext& cg_ctx = codegen->context();
LlvmCodeGen::LlvmBuilder builder(cg_ctx);
// Constant
Value* zero = ConstantInt::get(codegen->get_type(TYPE_TINYINT), 0);
Value* one = ConstantInt::get(codegen->get_type(TYPE_TINYINT), 1);
// Block
BasicBlock* entry_block = BasicBlock::Create(cg_ctx, "entry", *fn);
BasicBlock* compute_block = BasicBlock::Create(cg_ctx, "compute", *fn);
BasicBlock* ret_block = BasicBlock::Create(cg_ctx, "ret", *fn);
// entry block
builder.SetInsertPoint(entry_block);
CodegenAnyVal child_val = CodegenAnyVal::create_call_wrapped(
codegen, &builder, _children[0]->type(), child_fn, args, "child_val");
// if (v1.is_null) return null;
Value* child_is_null = child_val.get_is_null();
builder.CreateCondBr(child_is_null, ret_block, compute_block);
// compute_block
builder.SetInsertPoint(compute_block);
Value* val = builder.CreateNot(child_val.get_val(), "val");
builder.CreateBr(ret_block);
// ret block
builder.SetInsertPoint(ret_block);
PHINode* is_null_phi = builder.CreatePHI(codegen->tinyint_type(), 2, "is_null_phi");
is_null_phi->addIncoming(one, entry_block);
is_null_phi->addIncoming(zero, compute_block);
PHINode* val_phi = builder.CreatePHI(val->getType(), 2, "val_phi");
Value* null = Constant::getNullValue(val->getType());
val_phi->addIncoming(null, entry_block);
val_phi->addIncoming(val, compute_block);
CodegenAnyVal result = CodegenAnyVal::get_non_null_val(codegen, &builder, type(), "result");
result.set_is_null(is_null_phi);
result.set_val(val_phi);
builder.CreateRet(result.value());
*fn = codegen->finalize_function(*fn);
_ir_compute_fn = *fn;
return Status::OK;
}
}