Case when expression modification

This commit is contained in:
qubin-ben
2023-05-22 02:41:44 +00:00
committed by ob-robot
parent c4fafdf6eb
commit da418b0c5a

View File

@ -240,75 +240,94 @@ int ObExprCase::eval_case_batch(const ObExpr &expr,
int64_t loop = (has_else) ? expr.arg_cnt_ - 1 : expr.arg_cnt_;
bool match_when = false;
ObDatum *results = expr.locate_batch_datums(ctx);
LOG_DEBUG("eval_case_batch", K(expr.arg_cnt_));
if (OB_ISNULL(results)) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("results frame is not init", K(ret));
} else {
ObBitVector &eval_flags = expr.get_evaluated_flags(ctx);
//my_skip 用于记录上一个skip的信息, 获取新的skip之后将其转为对then求值的skip
ObBitVector &my_skip = expr.get_pvt_skip(ctx);
ObBitVector *last_skip = nullptr;
my_skip.deep_copy(skip, batch_size);
ObBitVector *case_when_match = nullptr;
ObBitVector *case_not_match = nullptr;
void * data = nullptr;
void * data1 = nullptr;
ObEvalCtx::TempAllocGuard alloc_guard(ctx);
if (OB_ISNULL(data = alloc_guard.get_allocator().alloc(ObBitVector::memory_size(batch_size)))) {
ret = OB_ALLOCATE_MEMORY_FAILED;
LOG_WARN("failed to alloc memory for last_skip", K(ret));
LOG_WARN("failed to alloc memory for case_when_match", K(ret));
} else if (OB_ISNULL(data1 = alloc_guard.get_allocator().alloc(ObBitVector::memory_size(batch_size)))) {
ret = OB_ALLOCATE_MEMORY_FAILED;
LOG_WARN("failed to alloc memory for case_when_match", K(ret));
} else {
last_skip = to_bit_vector(data);
//last_skip = eval_flags | skip
last_skip->bit_calculate(skip, eval_flags, batch_size,
case_when_match = to_bit_vector(data);
case_not_match = to_bit_vector(data1);
case_when_match->reset(batch_size);
case_not_match->reset(batch_size);
//case_when_match = eval_flags | skip
case_when_match->bit_calculate(skip, eval_flags, batch_size,
[](const uint64_t l, const uint64_t r) { return (l | r); });
case_not_match->bit_calculate(skip, eval_flags, batch_size,
[](const uint64_t l, const uint64_t r) { return (l | r); });
}
//eval when_datum according to last_skip, and set last_skip
//eval then datum according to my_skip = (my_skip SOR last_skip)
// E.G
// SELECT CASE WHEN expr1 THEN expr2 WHEN expr3 THEN expr4 ... ELSE exprN END
// the logic is
// 1. calc when branch, save result in when_datums and use match_when flag
// to mark which rows are matched in when branch and these rows should be
// calculated in then branch
// 2. calc then branch, put matching result(then_datums) into output datums
// (results)
// REPEAT 1. and 2.
// ...
// LAST.
// calc else branch and put matching result(then_datums) into output datums
for (int64_t expr_idx = 0; OB_SUCC(ret) && expr_idx < loop; expr_idx += 2) {
if (OB_FAIL(expr.args_[expr_idx]->eval_batch(ctx, *last_skip, batch_size))) {
if (OB_FAIL(expr.args_[expr_idx]->eval_batch(ctx, *case_when_match, batch_size))) {
LOG_WARN("failed to eval batch", K(ret), K(expr_idx));
} else {
ObDatumVector when_datums = expr.args_[expr_idx]->locate_expr_datumvector(ctx);
//first eval when datums
for (int64_t j = 0; OB_SUCC(ret) && j < batch_size; ++j) {
if (last_skip->at(j)) {
if (case_when_match->at(j)) {
continue;
}
if (OB_FAIL(check_is_match(*when_datums.at(j), match_when))) {
LOG_WARN("check is when expr match failed", K(ret), K(j));
} else if (match_when) {
last_skip->set(j);
case_when_match->set(j);
} else {
// not match, mark case_not_match to stop calculating then branch
case_not_match->set(j);
}
}
//now set the my_skip to eval then datums, my_skip = my_skip SAME OR last_skip
my_skip.bit_calculate(my_skip,
*last_skip,
batch_size,
[](const uint64_t l, const uint64_t r) { return ~(l ^ r); });
//now eval then datums
if (OB_FAIL(expr.args_[expr_idx + 1]->eval_batch(ctx, my_skip, batch_size))) {
if (OB_FAIL(ret)) {
} else if (OB_FAIL(expr.args_[expr_idx + 1]->eval_batch(ctx, *case_not_match, batch_size))) {
LOG_WARN("failed to eval batch", K(ret), K(expr_idx + 1));
} else {
ObDatumVector then_datums = expr.args_[expr_idx + 1]->locate_expr_datumvector(ctx);
for (int64_t j = 0; OB_SUCC(ret) && j < batch_size; ++j) {
if (my_skip.at(j)) {
if (case_not_match->at(j)) {
continue;
}
results[j].set_datum(*then_datums.at(j));
eval_flags.set(j);
}
//we need save the status had evaluated, so use my_skip to backup
my_skip.deep_copy(*last_skip, batch_size);
// rows matched in this round should not match in next round, therefor,
// copy last round matched rows flag(case_when_match) into case_not_match
case_not_match->deep_copy(*case_when_match, batch_size);
}
}
}
//now set the result of the rest, according to last_skip
//now set the result of the rest, skip rows already matched (case_when_match)
if (OB_SUCC(ret)) {
if (has_else) {
if (OB_FAIL(expr.args_[expr.arg_cnt_ - 1]->eval_batch(ctx, *last_skip, batch_size))) {
if (OB_FAIL(expr.args_[expr.arg_cnt_ - 1]->eval_batch(ctx, *case_when_match, batch_size))) {
LOG_WARN("failed to eval batch", K(ret));
} else {
ObDatumVector else_datums = expr.args_[expr.arg_cnt_ - 1]->locate_expr_datumvector(ctx);
for (int64_t j = 0; OB_SUCC(ret) && j < batch_size; ++j) {
if (last_skip->at(j)) {
if (case_when_match->at(j)) {
continue;
}
results[j].set_datum(*else_datums.at(j));
@ -317,7 +336,7 @@ int ObExprCase::eval_case_batch(const ObExpr &expr,
}
} else {
for (int64_t j = 0; OB_SUCC(ret) && j < batch_size; ++j) {
if (last_skip->at(j)) {
if (case_when_match->at(j)) {
continue;
}
results[j].set_null();