[opt](lambda) let lambda expressions refer to outer slots (#45186)

Author: shee
Date: 2024-12-11 18:55:49 +08:00
Committed by: GitHub
Parent: 6d4e541976
Commit: fb407f2e94
12 changed files with 327 additions and 52 deletions

View File

@@ -17,7 +17,7 @@
#pragma once
#include <fmt/core.h>
#include <runtime/runtime_state.h>
#include "common/status.h"
#include "vec/core/block.h"
@@ -31,9 +31,16 @@ public:
virtual std::string get_name() const = 0;
virtual doris::Status prepare(RuntimeState* state) {
batch_size = state->batch_size();
return Status::OK();
}
virtual doris::Status execute(VExprContext* context, doris::vectorized::Block* block,
int* result_column_id, const DataTypePtr& result_type,
const VExprSPtrs& children) = 0;
int batch_size;
};
using LambdaFunctionPtr = std::shared_ptr<LambdaFunction>;
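
The base class gains a prepare() hook that caches the session batch size, so execute() implementations can cap how many expanded rows they materialize per lambda block. Below is a minimal standalone sketch of that contract; RuntimeStateStub and LambdaFunctionStub are hypothetical stand-ins for illustration, not the real Doris classes.

#include <cassert>

// Hypothetical stand-ins for illustration only; not the Doris classes.
struct RuntimeStateStub {
    int configured_batch_size = 4096;
    int batch_size() const { return configured_batch_size; }
};

struct LambdaFunctionStub {
    int batch_size = 0;
    // prepare() runs once before execution and caches the batch size
    void prepare(const RuntimeStateStub& state) { batch_size = state.batch_size(); }
};

int main() {
    RuntimeStateStub state;
    LambdaFunctionStub fn;
    fn.prepare(state);
    assert(fn.batch_size == 4096);  // execute() later fills at most this many expanded rows per pass
}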

View File

@@ -15,6 +15,10 @@
// specific language governing permissions and limitations
// under the License.
#include <vec/data_types/data_type_number.h>
#include <vec/exprs/vcolumn_ref.h>
#include <vec/exprs/vslot_ref.h>
#include <memory>
#include <string>
#include <utility>
@@ -47,6 +51,26 @@ class VExprContext;
namespace doris::vectorized {
// extend a block with all required parameters
struct LambdaArgs {
// column ids of all the outer slots referenced by the lambda function body
std::vector<int> output_slot_ref_indexs;
// index of the row in the original block that is currently being expanded
int64_t current_row_idx = 0;
// when the lambda block fills up, an array may be split; this records where it was cut
int64_t current_offset_in_array = 0;
// start position of the current array in the nested column
size_t array_start = 0;
// number of elements in the current array
int64_t cur_size = 0;
// offsets of the array column
const ColumnArray::Offsets64* offsets_ptr = nullptr;
// how many expanded rows are pending for the current source row (outer slots are repeated this many times)
int current_repeat_times = 0;
// whether the current row of the original block has been fully expanded
bool current_row_eos = false;
};
class ArrayMapFunction : public LambdaFunction {
ENABLE_FACTORY_CREATOR(ArrayMapFunction);
@@ -62,8 +86,33 @@ public:
doris::Status execute(VExprContext* context, doris::vectorized::Block* block,
int* result_column_id, const DataTypePtr& result_type,
const VExprSPtrs& children) override {
///* array_map(lambda,arg1,arg2,.....) *///
LambdaArgs args;
// collect the slot refs used in the lambda function body
_collect_slot_ref_column_id(children[0], args);
int gap = 0;
if (!args.output_slot_ref_indexs.empty()) {
auto max_id = std::max_element(args.output_slot_ref_indexs.begin(),
args.output_slot_ref_indexs.end());
gap = *max_id + 1;
_set_column_ref_column_id(children[0], gap);
}
std::vector<std::string> names(gap);
DataTypes data_types(gap);
for (int i = 0; i < gap; ++i) {
if (_contains_column_id(args, i)) {
names[i] = block->get_by_position(i).name;
data_types[i] = block->get_by_position(i).type;
} else {
// pad the unreferenced positions with mock data
names[i] = "temp";
data_types[i] = std::make_shared<DataTypeUInt8>();
}
}
///* array_map(lambda,arg1,arg2,.....) *///
//1. child[1:end]->execute(src_block)
doris::vectorized::ColumnNumbers arguments(children.size() - 1);
for (int i = 1; i < children.size(); ++i) {
@@ -82,14 +131,13 @@ public:
MutableColumnPtr array_column_offset;
int nested_array_column_rows = 0;
ColumnPtr first_array_offsets = nullptr;
//2. get the result column from the executed expr; what we need is the nested column of the array
Block lambda_block;
std::vector<ColumnPtr> lambda_datas(arguments.size());
for (int i = 0; i < arguments.size(); ++i) {
const auto& array_column_type_name = block->get_by_position(arguments[i]);
auto column_array = array_column_type_name.column->convert_to_full_column_if_const();
auto type_array = array_column_type_name.type;
if (type_array->is_nullable()) {
// get the nullmap of nullable column
const auto& column_array_nullmap =
@@ -118,6 +166,7 @@ public:
auto& off_data = assert_cast<const ColumnArray::ColumnOffsets&>(
col_array.get_offsets_column());
array_column_offset = off_data.clone_resized(col_array.get_offsets_column().size());
args.offsets_ptr = &col_array.get_offsets();
} else {
// select array_map((x,y)->x+y,c_array1,[0,1,2,3]) from array_test2;
// c_array1: [0,1,2,3,4,5,6,7,8,9]
@@ -136,57 +185,164 @@ public:
nested_array_column_rows, i + 1, col_array.get_data_ptr()->size());
}
}
// insert the data column to the new block
ColumnWithTypeAndName data_column {col_array.get_data_ptr(), col_type.get_nested_type(),
"R" + array_column_type_name.name};
lambda_block.insert(std::move(data_column));
lambda_datas[i] = col_array.get_data_ptr();
names.push_back("R" + array_column_type_name.name);
data_types.push_back(col_type.get_nested_type());
}
//3. child[0]->execute(new_block)
RETURN_IF_ERROR(children[0]->execute(context, &lambda_block, result_column_id));
ColumnPtr result_col = nullptr;
DataTypePtr res_type;
std::string res_name;
auto res_col = lambda_block.get_by_position(*result_column_id)
.column->convert_to_full_column_if_const();
auto res_type = lambda_block.get_by_position(*result_column_id).type;
auto res_name = lambda_block.get_by_position(*result_column_id).name;
//process first row
args.array_start = (*args.offsets_ptr)[args.current_row_idx - 1];
args.cur_size = (*args.offsets_ptr)[args.current_row_idx] - args.array_start;
while (args.current_row_idx < block->rows()) {
Block lambda_block;
for (int i = 0; i < names.size(); i++) {
ColumnWithTypeAndName data_column;
if (_contains_column_id(args, i) || i >= gap) {
data_column = ColumnWithTypeAndName(data_types[i], names[i]);
} else {
data_column = ColumnWithTypeAndName(
data_types[i]->create_column_const_with_default_value(0), data_types[i],
names[i]);
}
lambda_block.insert(std::move(data_column));
}
MutableColumns columns = lambda_block.mutate_columns();
while (columns[gap]->size() < batch_size) {
long max_step = batch_size - columns[gap]->size();
long current_step =
std::min(max_step, (long)(args.cur_size - args.current_offset_in_array));
size_t pos = args.array_start + args.current_offset_in_array;
for (int i = 0; i < arguments.size(); ++i) {
columns[gap + i]->insert_range_from(*lambda_datas[i], pos, current_step);
}
args.current_offset_in_array += current_step;
args.current_repeat_times += current_step;
if (args.current_offset_in_array >= args.cur_size) {
args.current_row_eos = true;
}
_extend_data(columns, block, args, gap);
if (args.current_row_eos) {
args.current_row_idx++;
args.current_offset_in_array = 0;
if (args.current_row_idx >= block->rows()) {
break;
}
args.current_row_eos = false;
args.array_start = (*args.offsets_ptr)[args.current_row_idx - 1];
args.cur_size = (*args.offsets_ptr)[args.current_row_idx] - args.array_start;
}
}
lambda_block.set_columns(std::move(columns));
//3. child[0]->execute(new_block)
RETURN_IF_ERROR(children[0]->execute(context, &lambda_block, result_column_id));
auto res_col = lambda_block.get_by_position(*result_column_id)
.column->convert_to_full_column_if_const();
res_type = lambda_block.get_by_position(*result_column_id).type;
res_name = lambda_block.get_by_position(*result_column_id).name;
if (!result_col) {
result_col = std::move(res_col);
} else {
MutableColumnPtr column = (*std::move(result_col)).mutate();
column->insert_range_from(*res_col, 0, res_col->size());
}
}
//4. get the result column after execution, reassemble it into a new array column, and return.
ColumnWithTypeAndName result_arr;
if (result_type->is_nullable()) {
if (res_type->is_nullable()) {
result_arr = {ColumnNullable::create(
ColumnArray::create(res_col, std::move(array_column_offset)),
std::move(outside_null_map)),
result_type, res_name};
result_arr = {
ColumnNullable::create(
ColumnArray::create(result_col, std::move(array_column_offset)),
std::move(outside_null_map)),
result_type, res_name};
} else {
// deal with eg: select array_map(x -> x is null, [null, 1, 2]);
// need to create the nested column null map for column array
auto nested_null_map = ColumnUInt8::create(res_col->size(), 0);
auto nested_null_map = ColumnUInt8::create(result_col->size(), 0);
result_arr = {
ColumnNullable::create(
ColumnArray::create(
ColumnNullable::create(res_col, std::move(nested_null_map)),
std::move(array_column_offset)),
ColumnArray::create(ColumnNullable::create(
result_col, std::move(nested_null_map)),
std::move(array_column_offset)),
std::move(outside_null_map)),
result_type, res_name};
}
} else {
if (res_type->is_nullable()) {
result_arr = {ColumnArray::create(res_col, std::move(array_column_offset)),
result_arr = {ColumnArray::create(result_col, std::move(array_column_offset)),
result_type, res_name};
} else {
auto nested_null_map = ColumnUInt8::create(res_col->size(), 0);
result_arr = {ColumnArray::create(
ColumnNullable::create(res_col, std::move(nested_null_map)),
std::move(array_column_offset)),
auto nested_null_map = ColumnUInt8::create(result_col->size(), 0);
result_arr = {ColumnArray::create(ColumnNullable::create(
result_col, std::move(nested_null_map)),
std::move(array_column_offset)),
result_type, res_name};
}
}
block->insert(std::move(result_arr));
*result_column_id = block->columns() - 1;
return Status::OK();
}
private:
bool _contains_column_id(LambdaArgs& args, int id) {
const auto it = std::find(args.output_slot_ref_indexs.begin(),
args.output_slot_ref_indexs.end(), id);
return it != args.output_slot_ref_indexs.end();
}
void _set_column_ref_column_id(VExprSPtr expr, int gap) {
for (const auto& child : expr->children()) {
if (child->is_column_ref()) {
auto* ref = static_cast<VColumnRef*>(child.get());
ref->set_gap(gap);
} else {
_set_column_ref_column_id(child, gap);
}
}
}
void _collect_slot_ref_column_id(VExprSPtr expr, LambdaArgs& args) {
for (const auto& child : expr->children()) {
if (child->is_slot_ref()) {
const auto* ref = static_cast<VSlotRef*>(child.get());
args.output_slot_ref_indexs.push_back(ref->column_id());
} else {
_collect_slot_ref_column_id(child, args);
}
}
}
void _extend_data(std::vector<MutableColumnPtr>& columns, Block* block, LambdaArgs& args,
int size) {
if (!args.current_repeat_times || !size) {
return;
}
for (int i = 0; i < size; i++) {
if (_contains_column_id(args, i)) {
auto src_column =
block->get_by_position(i).column->convert_to_full_column_if_const();
columns[i]->insert_many_from(*src_column, args.current_row_idx,
args.current_repeat_times);
} else {
// must be column const
DCHECK(is_column_const(*columns[i]));
columns[i]->resize(columns[i]->size() + args.current_repeat_times);
}
}
args.current_repeat_times = 0;
}
};
void register_function_array_map(doris::vectorized::LambdaFunctionFactory& factory) {
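
Conceptually, execute() now expands each source row into one row per array element: the nested array elements become the lambda's argument columns, every outer slot referenced in the lambda body is repeated once per element, and the lambda body is evaluated whenever the expanded block reaches batch_size rows. Below is a minimal, self-contained sketch of that expansion loop in plain C++ (std::vector instead of Doris columns; all names such as expand_and_flush are illustrative, not the real API).

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// One source row: the outer slot value and the array argument.
struct SourceRow {
    int64_t outer_slot;             // e.g. column k1 referenced inside the lambda
    std::vector<int64_t> array_arg; // e.g. c_array1
};

// Expand rows into flat (outer_slot, element) pairs, flushing every batch_size rows.
// `flush` stands in for "execute the lambda body on the accumulated block".
template <typename Flush>
void expand_and_flush(const std::vector<SourceRow>& rows, size_t batch_size, Flush flush) {
    std::vector<int64_t> outer_col;   // outer slot repeated once per element
    std::vector<int64_t> element_col; // nested array elements
    for (const SourceRow& row : rows) {
        size_t offset = 0;            // analogue of current_offset_in_array
        while (offset < row.array_arg.size()) {
            size_t max_step = batch_size - outer_col.size();
            size_t step = std::min(max_step, row.array_arg.size() - offset);
            // copy `step` array elements and repeat the outer slot `step` times
            element_col.insert(element_col.end(), row.array_arg.begin() + offset,
                               row.array_arg.begin() + offset + step);
            outer_col.insert(outer_col.end(), step, row.outer_slot);
            offset += step;
            if (outer_col.size() == batch_size) { // lambda block is full
                flush(outer_col, element_col);
                outer_col.clear();
                element_col.clear();
            }
        }
    }
    if (!outer_col.empty()) flush(outer_col, element_col); // tail batch
}

int main() {
    std::vector<SourceRow> rows = {{1, {1, 2, 3, 4}}, {2, {6, 7, 8, 9}}};
    expand_and_flush(rows, /*batch_size=*/3, [](const auto& outer, const auto& elems) {
        for (size_t i = 0; i < elems.size(); ++i) {
            std::cout << elems[i] + outer[i] << ' '; // lambda body: x -> x + k1
        }
        std::cout << '\n';
    });
}

With batch_size = 3 this prints three batches (2 3 4, 5 8 9, 10 11), mirroring how the real implementation may split one array across several lambda blocks and then stitch the partial results back together through result_col before rebuilding the array column.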

View File

@@ -16,6 +16,8 @@
// under the License.
#pragma once
#include <atomic>
#include "runtime/descriptors.h"
#include "runtime/runtime_state.h"
#include "vec/exprs/vexpr.h"
@@ -57,7 +59,7 @@ public:
Status execute(VExprContext* context, Block* block, int* result_column_id) override {
DCHECK(_open_finished || _getting_const_col);
*result_column_id = _column_id;
*result_column_id = _column_id + _gap;
return Status::OK();
}
@@ -67,6 +69,12 @@ public:
const std::string& expr_name() const override { return _column_name; }
void set_gap(int gap) {
if (_gap == 0) {
_gap = gap;
}
}
std::string debug_string() const override {
std::stringstream out;
out << "VColumnRef(slot_id: " << _column_id << ",column_name: " << _column_name
@@ -76,6 +84,7 @@ public:
private:
int _column_id;
std::atomic<int> _gap = 0;
std::string _column_name;
};
} // namespace vectorized
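
Because the expanded lambda block lays out the referenced outer-slot columns first, a lambda-argument column ref is shifted by that count (the "gap"), and set_gap() takes effect only once so repeated calls cannot shift it twice. A tiny illustrative sketch follows; ColumnRefSketch and resolved_column_id are made-up names, not the real VColumnRef API.

#include <cassert>

// Illustrative stand-in for VColumnRef.
class ColumnRefSketch {
public:
    explicit ColumnRefSketch(int column_id) : _column_id(column_id) {}
    // shift by the number of outer-slot columns placed before the lambda arguments
    void set_gap(int gap) {
        if (_gap == 0) {
            _gap = gap;  // only the first assignment sticks
        }
    }
    int resolved_column_id() const { return _column_id + _gap; }

private:
    int _column_id;
    int _gap = 0;
};

int main() {
    ColumnRefSketch x(0);  // lambda argument `x` occupied position 0 before this change
    x.set_gap(2);          // two outer slots (e.g. k1, k2) now sit at positions 0 and 1
    x.set_gap(5);          // ignored: the gap is only set once
    assert(x.resolved_column_id() == 2);
}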

View File

@@ -142,6 +142,9 @@ public:
TypeDescriptor type() { return _type; }
bool is_slot_ref() const { return _node_type == TExprNodeType::SLOT_REF; }
bool is_column_ref() const { return _node_type == TExprNodeType::COLUMN_REF; }
virtual bool is_literal() const { return false; }
TExprNodeType::type node_type() const { return _node_type; }

View File

@@ -50,6 +50,7 @@ public:
return Status::InternalError("Lambda Function {} is not implemented.",
_fn.name.function_name);
}
RETURN_IF_ERROR(_lambda_function->prepare(state));
_prepare_finished = true;
return Status::OK();
}

View File

@@ -18,7 +18,6 @@
package org.apache.doris.nereids.analyzer;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Suppliers;
@@ -62,20 +61,18 @@ public class Scope {
private final Optional<Scope> outerScope;
private final List<Slot> slots;
private final Optional<SubqueryExpr> ownerSubquery;
private final Set<Slot> correlatedSlots;
private final boolean buildNameToSlot;
private final Supplier<ListMultimap<String, Slot>> nameToSlot;
public Scope(List<? extends Slot> slots) {
this(Optional.empty(), slots, Optional.empty());
this(Optional.empty(), slots);
}
/** Scope */
public Scope(Optional<Scope> outerScope, List<? extends Slot> slots, Optional<SubqueryExpr> subqueryExpr) {
public Scope(Optional<Scope> outerScope, List<? extends Slot> slots) {
this.outerScope = Objects.requireNonNull(outerScope, "outerScope can not be null");
this.slots = Utils.fastToImmutableList(Objects.requireNonNull(slots, "slots can not be null"));
this.ownerSubquery = Objects.requireNonNull(subqueryExpr, "subqueryExpr can not be null");
this.correlatedSlots = Sets.newLinkedHashSet();
this.buildNameToSlot = slots.size() > 500;
this.nameToSlot = buildNameToSlot ? Suppliers.memoize(this::buildNameToSlot) : null;
@@ -89,10 +86,6 @@ public class Scope {
return outerScope;
}
public Optional<SubqueryExpr> getSubquery() {
return ownerSubquery;
}
public Set<Slot> getCorrelatedSlots() {
return correlatedSlots;
}

View File

@@ -963,7 +963,7 @@ public class BindExpression implements AnalysisRuleFactory {
private Scope toScope(CascadesContext cascadesContext, List<? extends Slot> slots) {
Optional<Scope> outerScope = cascadesContext.getOuterScope();
if (outerScope.isPresent()) {
return new Scope(outerScope, slots, outerScope.get().getSubquery());
return new Scope(outerScope, slots);
} else {
return new Scope(slots);
}

View File

@@ -796,8 +796,9 @@ public class ExpressionAnalyzer extends SubExprAnalyzer<ExpressionRewriteContext
.map(ArrayItemReference::toSlot)
.collect(ImmutableList.toImmutableList());
ExpressionAnalyzer lambdaAnalyzer = new ExpressionAnalyzer(currentPlan, new Scope(boundedSlots),
context == null ? null : context.cascadesContext, true, false) {
ExpressionAnalyzer lambdaAnalyzer = new ExpressionAnalyzer(currentPlan, new Scope(Optional.of(getScope()),
boundedSlots), context == null ? null : context.cascadesContext,
true, true) {
@Override
protected void couldNotFoundColumn(UnboundSlot unboundSlot, String tableName) {
throw new AnalysisException("Unknown lambda slot '"

View File

@@ -46,7 +46,6 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
/**
* Use the visitor to iterate sub expression.
@@ -188,19 +187,15 @@ class SubExprAnalyzer<T> extends DefaultExpressionRewriter<T> {
}
CascadesContext subqueryContext = CascadesContext.newContextWithCteContext(
cascadesContext, expr.getQueryPlan(), cascadesContext.getCteContext());
Scope subqueryScope = genScopeWithSubquery(expr);
// don't use `getScope()` because we only need `getScope().getOuterScope()` and `getScope().getSlots()`
// otherwise unexpected errors may occur
Scope subqueryScope = new Scope(getScope().getOuterScope(), getScope().getSlots());
subqueryContext.setOuterScope(subqueryScope);
subqueryContext.newAnalyzer().analyze();
return new AnalyzedResult((LogicalPlan) subqueryContext.getRewritePlan(),
subqueryScope.getCorrelatedSlots());
}
private Scope genScopeWithSubquery(SubqueryExpr expr) {
return new Scope(getScope().getOuterScope(),
getScope().getSlots(),
Optional.ofNullable(expr));
}
public Scope getScope() {
return scope;
}

View File

@@ -0,0 +1,41 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !select_1 --
\N \N \N \N \N
1 2 [1, 2, 3, 4] [null, 2, 3] [1, 1, 1, 1]
2 3 [6, 7, null, 9] [4, null, 6] [1, 1, null, 1]
3 4 \N [4, 5, 6] \N
-- !select_2 --
\N \N \N \N \N
1 2 [1, 2, 3, 4] [null, 2, 3] [1, 1, 1, 1]
2 3 [6, 7, null, 9] [4, null, 6] [1, 1, null, 1]
3 4 \N [4, 5, 6] \N
-- !select_3 --
\N \N \N \N \N
1 2 [1, 2, 3, 4] [null, 2, 3] [1, 1, 1, 1]
2 3 [6, 7, null, 9] [4, null, 6] [1, 1, null, 1]
3 4 \N [4, 5, 6] \N
-- !select_4 --
\N \N \N \N \N
1 2 [1, 2, 3, 4] [null, 2, 3] [1, 1, 1, 1]
2 3 [6, 7, null, 9] [4, null, 6] [1, 1, null, 1]
3 4 \N [4, 5, 6] \N
-- !select_5 --
\N \N \N \N \N
1 2 [1, 2, 3, 4] [null, 2, 3] [1, 1, 1, 1]
2 3 [6, 7, null, 9] [4, null, 6] [1, 1, null, 1]
3 4 \N [4, 5, 6] \N
-- !select_6 --
\N \N \N \N \N
4 5 [6, 7, null, 9] [4, 5, 6, 7] [0, 0, null, 0]
5 6 [10, 11, 12, 13] [8, 9, null, 11] [0, 0, null, 0]
6 7 \N \N \N
-- !select_7 --
4 5 [6, 7, null, 9] [4, 5, 6, 7] [0, 0, null, 0]
5 6 [10, 11, 12, 13] [8, 9, null, 11] [0, 0, null, 0]

View File

@@ -75,8 +75,7 @@ suite("test_array_map_function") {
test {
sql"""select c_array1,array_max(array_map(x->countequal(c_array1,x),c_array1)) from array_test2;"""
check{result, exception, startTime, endTime ->
assertTrue(exception != null)
logger.info(exception.message)
assertTrue(exception == null)
}
}

View File

@@ -0,0 +1,70 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
suite("test_array_map_function_with_column") {
def tableName = "array_test_with_column"
sql "DROP TABLE IF EXISTS ${tableName}"
sql """
CREATE TABLE IF NOT EXISTS ${tableName} (
`k1` int(11) NULL COMMENT "",
`k2` int(11) NULL COMMENT "",
`c_array1` ARRAY<int(11)> NULL COMMENT "",
`c_array2` ARRAY<int(11)> NULL COMMENT ""
) ENGINE=OLAP
DISTRIBUTED BY HASH(`k1`,`k2`) BUCKETS 3
PROPERTIES (
"replication_allocation" = "tag.location.default: 1"
);
"""
sql """INSERT INTO ${tableName} values
(1, 2, [1,2,3,4], [null,2,3]),
(2, 3, [6,7,null,9], [4,null,6]),
(3, 4, NULL, [4, 5, 6]),
(NULL, NULL, NULL, NULL);
"""
qt_select_1 "select *,array_map(x->x+k1+k2 > k1*k2,c_array1) from ${tableName} order by k1;"
sql "set batch_size = 1;"
qt_select_2 "select *,array_map(x->x+k1+k2 > k1*k2,c_array1) from ${tableName} order by k1;"
sql "set batch_size = 4;"
qt_select_3 "select *,array_map(x->x+k1+k2 > k1*k2,c_array1) from ${tableName} order by k1;"
sql "set batch_size = 6;"
qt_select_4 "select *,array_map(x->x+k1+k2 > k1*k2,c_array1) from ${tableName} order by k1;"
sql "set batch_size = 8;"
qt_select_5 "select *,array_map(x->x+k1+k2 > k1*k2,c_array1) from ${tableName} order by k1;"
sql "truncate table ${tableName};"
sql """INSERT INTO ${tableName} values
(4, 5, [6,7,null,9], [4,5,6,7]),
(5, 6, [10,11,12,13], [8,9,null,11]),
(6, 7, NULL, NULL),
(NULL, NULL, NULL, NULL);
"""
qt_select_6 "select *,array_map((x,y)->x+k1+k2 > y+k1*k2,c_array1,c_array2) from ${tableName} order by k1;"
qt_select_7 "select *,array_map((x,y)->x+k1+k2 > y+k1*k2,c_array1,c_array2) from ${tableName} where array_count((x,y) -> k1*x>y+k2, c_array1, c_array2) > 1 order by k1;"
sql "DROP TABLE IF EXISTS ${tableName}"
}