From ad07dec0ed5a36985d9d747d34a153357c1a90ec Mon Sep 17 00:00:00 2001 From: amory Date: Thu, 22 Feb 2024 11:25:06 +0800 Subject: [PATCH] [Improve](InPredict) enhance in predict with struct type (#30840) --- be/src/vec/exprs/vin_predicate.cpp | 3 + be/src/vec/functions/function_struct_in.cpp | 31 ++++ be/src/vec/functions/function_struct_in.h | 169 ++++++++++++++++++ be/src/vec/functions/in.h | 3 +- .../vec/functions/simple_function_factory.h | 2 + be/test/vec/columns/column_hash_func_test.cpp | 26 +++ .../trees/expressions/InPredicate.java | 10 ++ .../doris/nereids/util/TypeCoercionUtils.java | 11 +- .../inpredicate_with_struct.out | 19 ++ .../inpredicate_with_struct.groovy | 41 +++++ 10 files changed, 311 insertions(+), 4 deletions(-) create mode 100644 be/src/vec/functions/function_struct_in.cpp create mode 100644 be/src/vec/functions/function_struct_in.h create mode 100644 regression-test/data/nereids_syntax_p0/inpredicate_with_struct.out create mode 100644 regression-test/suites/nereids_syntax_p0/inpredicate_with_struct.groovy diff --git a/be/src/vec/exprs/vin_predicate.cpp b/be/src/vec/exprs/vin_predicate.cpp index 896b2a903d..6f57828ef2 100644 --- a/be/src/vec/exprs/vin_predicate.cpp +++ b/be/src/vec/exprs/vin_predicate.cpp @@ -66,6 +66,9 @@ Status VInPredicate::prepare(RuntimeState* state, const RowDescriptor& desc, // construct the proper function_name std::string head(_is_not_in ? "not_" : ""); std::string real_function_name = head + std::string(function_name); + if (is_struct(remove_nullable(argument_template[0].type))) { + real_function_name = "struct_" + real_function_name; + } _function = SimpleFunctionFactory::instance().get_function(real_function_name, argument_template, _data_type); if (_function == nullptr) { diff --git a/be/src/vec/functions/function_struct_in.cpp b/be/src/vec/functions/function_struct_in.cpp new file mode 100644 index 0000000000..943e3a80f4 --- /dev/null +++ b/be/src/vec/functions/function_struct_in.cpp @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from +// and modified by Doris + +#include "vec/functions/function_struct_in.h" + +#include "vec/functions/simple_function_factory.h" + +namespace doris::vectorized { + +void register_function_struct_in(SimpleFunctionFactory& factory) { + factory.register_function>(); + factory.register_function>(); +} + +} // namespace doris::vectorized diff --git a/be/src/vec/functions/function_struct_in.h b/be/src/vec/functions/function_struct_in.h new file mode 100644 index 0000000000..13ca2b1fec --- /dev/null +++ b/be/src/vec/functions/function_struct_in.h @@ -0,0 +1,169 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +#include "common/status.h" +#include "vec/columns/column.h" +#include "vec/columns/column_const.h" +#include "vec/columns/column_nullable.h" +#include "vec/columns/column_struct.h" +#include "vec/columns/column_vector.h" +#include "vec/columns/columns_number.h" +#include "vec/core/block.h" +#include "vec/data_types/data_type_factory.hpp" +#include "vec/data_types/data_type_nullable.h" +#include "vec/data_types/data_type_number.h" +#include "vec/functions/function.h" + +namespace doris::vectorized { +struct ColumnRowRef { + ColumnPtr column; + size_t row_idx; + + // equals when call set insert, this operator will be used + bool operator==(const ColumnRowRef& other) const { + return column->compare_at(row_idx, other.row_idx, *column, 0) == 0; + } + // compare + bool operator<(const ColumnRowRef& other) const { + return column->compare_at(row_idx, other.row_idx, *column, 0) < 0; + } + + // when call set find, will use hash to find + size_t operator()(const ColumnRowRef& a) const { + uint32_t hash_val = 0; + a.column->update_crc_with_value(a.row_idx, a.row_idx + 1, hash_val, nullptr); + return hash_val; + } +}; + +template +class FunctionStructIn : public IFunction { +public: + static constexpr auto name = negative ? "struct_not_in" : "struct_in"; + + static FunctionPtr create() { return std::make_shared(); } + + String get_name() const override { return name; } + + bool is_variadic() const override { return true; } + + size_t get_number_of_arguments() const override { return 0; } + + DataTypePtr get_return_type_impl(const DataTypes& args) const override { + for (const auto& arg : args) { + if (arg->is_nullable()) { + return make_nullable(std::make_shared()); + } + } + return std::make_shared(); + } + + bool use_default_implementation_for_nulls() const override { return false; } + + // make data in context into a set + Status open(FunctionContext* context, FunctionContext::FunctionStateScope scope) override { + DCHECK(context->get_num_args() >= 1); + auto* col_desc = context->get_arg_type(0); + DataTypePtr args_type = DataTypeFactory::instance().create_data_type(*col_desc); + MutableColumnPtr column_struct_ptr_args = remove_nullable(args_type)->create_column(); + NullMap null_map(context->get_num_args(), false); + for (int i = 1; i < context->get_num_args(); ++i) { + // FE should make element type consistent and + // equalize the length of the elements in struct + const auto& const_column_ptr = context->get_constant_col(i); + if (const_column_ptr == nullptr) { + break; + } + const auto& [col, _] = unpack_if_const(const_column_ptr->column_ptr); + if (col->is_nullable()) { + auto* null_col = vectorized::check_and_get_column(col); + if (null_col->has_null()) { + null_in_set = true; + null_map[i - 1] = true; + } else { + column_struct_ptr_args->insert_from(null_col->get_nested_column(), 0); + } + } else { + column_struct_ptr_args->insert_from(*col, 0); + } + } + ColumnPtr column_ptr = std::move(column_struct_ptr_args); + // make StructRef into set + for (size_t i = 1; i < context->get_num_args(); ++i) { + if (null_in_set && null_map[i - 1]) { + continue; + } + args_set.insert({column_ptr, i - 1}); + } + return Status::OK(); + } + + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + size_t result, size_t input_rows_count) const override { + auto res = ColumnUInt8::create(); + ColumnUInt8::Container& vec_res = res->get_data(); + vec_res.resize(input_rows_count); + + ColumnUInt8::MutablePtr col_null_map_to; + col_null_map_to = ColumnUInt8::create(input_rows_count, false); + auto& vec_null_map_to = col_null_map_to->get_data(); + + const ColumnWithTypeAndName& left_arg = block.get_by_position(arguments[0]); + const auto& [materialized_column, col_const] = unpack_if_const(left_arg.column); + + for (size_t i = 0; i < input_rows_count; ++i) { + ColumnRowRef ref({materialized_column, i}); + bool find = args_set.find({materialized_column, i}) != args_set.end(); + if constexpr (negative) { + vec_res[i] = !find; + } else { + vec_res[i] = find; + } + if (null_in_set) { + vec_null_map_to[i] = negative == vec_res[i]; + } else { + vec_null_map_to[i] = false; + } + } + + if (block.get_by_position(result).type->is_nullable()) { + block.replace_by_position( + result, ColumnNullable::create(std::move(res), std::move(col_null_map_to))); + } else { + block.replace_by_position(result, std::move(res)); + } + return Status::OK(); + } + +private: + std::unordered_set args_set; + bool null_in_set = false; +}; + +} // namespace doris::vectorized diff --git a/be/src/vec/functions/in.h b/be/src/vec/functions/in.h index 42cb5be616..9fa182caf0 100644 --- a/be/src/vec/functions/in.h +++ b/be/src/vec/functions/in.h @@ -125,7 +125,6 @@ public: state->hybrid_set.reset( create_set(context->get_arg_type(0)->type, get_size_with_out_null(context))); } - for (int i = 1; i < context->get_num_args(); ++i) { const auto& const_column_ptr = context->get_constant_col(i); if (const_column_ptr != nullptr) { @@ -193,7 +192,7 @@ public: } } else { // non-nullable - if (materialized_column->is_column_string()) { + if (WhichDataType(left_arg.type).is_string()) { const auto* column_string_ptr = assert_cast(materialized_column.get()); search_hash_set(in_state, input_rows_count, vec_res, column_string_ptr); diff --git a/be/src/vec/functions/simple_function_factory.h b/be/src/vec/functions/simple_function_factory.h index 9bedc204cb..b1c1b394bf 100644 --- a/be/src/vec/functions/simple_function_factory.h +++ b/be/src/vec/functions/simple_function_factory.h @@ -65,6 +65,7 @@ void register_function_running_difference(SimpleFunctionFactory& factory); void register_function_date_time_to_string(SimpleFunctionFactory& factory); void register_function_date_time_string_to_string(SimpleFunctionFactory& factory); void register_function_in(SimpleFunctionFactory& factory); +void register_function_struct_in(SimpleFunctionFactory& factory); void register_function_if(SimpleFunctionFactory& factory); void register_function_nullif(SimpleFunctionFactory& factory); void register_function_date_time_computation(SimpleFunctionFactory& factory); @@ -245,6 +246,7 @@ public: register_function_time_of_function(instance); register_function_string(instance); register_function_in(instance); + register_function_struct_in(instance); register_function_if(instance); register_function_nullif(instance); register_function_date_time_computation(instance); diff --git a/be/test/vec/columns/column_hash_func_test.cpp b/be/test/vec/columns/column_hash_func_test.cpp index 0a9471a1ab..7b2d5f2ddd 100644 --- a/be/test/vec/columns/column_hash_func_test.cpp +++ b/be/test/vec/columns/column_hash_func_test.cpp @@ -233,4 +233,30 @@ TEST(HashFuncTest, StructTypeTest) { std::cout << crc_hashes[0] << std::endl; } +TEST(HashFuncTest, StructTypeTestWithSepcificValueCrcHash) { + DataTypePtr n1 = std::make_shared(); + DataTypePtr s1 = std::make_shared(); + DataTypes dataTypes; + dataTypes.push_back(n1); + dataTypes.push_back(s1); + + Tuple t; + t.push_back(Int64(1)); + t.push_back(String("hello")); + + DataTypePtr a = std::make_shared(dataTypes); + std::cout << a->get_name() << std::endl; + MutableColumnPtr struct_mutable_col = a->create_column(); + struct_mutable_col->insert(t); + + uint32_t hash_val = 0; + struct_mutable_col->update_crc_with_value(0, 1, hash_val, nullptr); + + for (int i = 0; i < 100; ++i) { + uint32_t should_same_hash_val = 0; + struct_mutable_col->update_crc_with_value(0, 1, should_same_hash_val, nullptr); + EXPECT_EQ(hash_val, should_same_hash_val); + } +} + } // namespace doris::vectorized diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java index c86a074dcf..f3ae9ce5b2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java @@ -76,6 +76,16 @@ public class InPredicate extends Expression { @Override public void checkLegalityBeforeTypeCoercion() { + if (children().get(0).getDataType().isStructType()) { + // we should check in value list is all struct type + for (int i = 1; i < children().size(); i++) { + if (!children().get(i).getDataType().isStructType() && !children().get(i).getDataType().isNullType()) { + throw new AnalysisException("in predicate struct should compare with struct type list, but got : " + + children().get(i).getDataType().toSql()); + } + } + return; + } children().forEach(c -> { if (c.getDataType().isObjectType()) { throw new AnalysisException("in predicate could not contains object type: " + this.toSql()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java index a3d4a84d3c..239972fc1d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java @@ -953,7 +953,8 @@ public class TypeCoercionUtils { if (inPredicate.getOptions().stream().map(Expression::getDataType) .allMatch(dt -> dt.equals(inPredicate.getCompareExpr().getDataType()))) { - if (!supportCompare(inPredicate.getCompareExpr().getDataType())) { + if (!supportCompare(inPredicate.getCompareExpr().getDataType()) + && !inPredicate.getCompareExpr().getDataType().isStructType()) { throw new AnalysisException("data type " + inPredicate.getCompareExpr().getDataType() + " could not used in InPredicate " + inPredicate.toSql()); } @@ -964,7 +965,13 @@ public class TypeCoercionUtils { .stream() .map(Expression::getDataType).collect(Collectors.toList()), true); - if (optionalCommonType.isPresent() && !supportCompare(optionalCommonType.get())) { + if (inPredicate.getCompareExpr().getDataType().isStructType() && optionalCommonType.isPresent() + && !optionalCommonType.get().isStructType()) { + throw new AnalysisException("data type " + optionalCommonType.get() + + " is not match " + inPredicate.getCompareExpr().getDataType() + " used in InPredicate"); + } + if (optionalCommonType.isPresent() && !supportCompare(optionalCommonType.get()) + && !optionalCommonType.get().isStructType()) { throw new AnalysisException("data type " + optionalCommonType.get() + " could not used in InPredicate " + inPredicate.toSql()); } diff --git a/regression-test/data/nereids_syntax_p0/inpredicate_with_struct.out b/regression-test/data/nereids_syntax_p0/inpredicate_with_struct.out new file mode 100644 index 0000000000..5e1e97a235 --- /dev/null +++ b/regression-test/data/nereids_syntax_p0/inpredicate_with_struct.out @@ -0,0 +1,19 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !in_predicate_11 -- +\N + +-- !in_predicate_12 -- +\N + +-- !in_predicate_13 -- +true + +-- !in_predicate_14 -- +false + +-- !in_predicate_15 -- +true + +-- !in_predicate_16 -- +false + diff --git a/regression-test/suites/nereids_syntax_p0/inpredicate_with_struct.groovy b/regression-test/suites/nereids_syntax_p0/inpredicate_with_struct.groovy new file mode 100644 index 0000000000..fb686fbc9e --- /dev/null +++ b/regression-test/suites/nereids_syntax_p0/inpredicate_with_struct.groovy @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("inpredicate_with_struct") { + sql """ set enable_nereids_planner = true;""" + sql """ set enable_fallback_to_original_planner=false;""" + // support struct type + order_qt_in_predicate_11 """ + select struct(1,"2") in (struct(1,3), null); + """ + order_qt_in_predicate_12 """ + select struct(1,"2") not in (struct(1,3), null); + """ + order_qt_in_predicate_13 """ + select struct(1,"2") in (struct(1,3), struct(1,2)); + """ + order_qt_in_predicate_14 """ + select struct(1,"2") not in (struct(1,3), struct(1,2)); + """ + order_qt_in_predicate_15 """ + select struct(1,"2") in (struct(1,3), struct(1,"2"), struct(1,1)); + """ + order_qt_in_predicate_16 """ + select struct(1,"2") not in (struct(1,3), struct(1,"2"), struct(1,1)); + """ +} +