[feature](agg) support aggregate function group_array_intersect (#33265)

This commit is contained in:
Chester
2024-04-16 16:25:48 +08:00
committed by yiguolei
parent 07a8f44443
commit 3096150d1b
12 changed files with 1115 additions and 2 deletions

View File

@ -0,0 +1,90 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionGroupArrayIntersect.cpp
// and modified by Doris
#include "vec/aggregate_functions/aggregate_function_group_array_intersect.h"
namespace doris::vectorized {
IAggregateFunction* create_with_extra_types(const DataTypePtr& nested_type,
const DataTypes& argument_types) {
WhichDataType which(nested_type);
if (which.idx == TypeIndex::Date || which.idx == TypeIndex::DateTime) {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"We don't support array<date> or array<datetime> for "
"group_array_intersect(), please use array<datev2> or array<datetimev2>.");
} else if (which.idx == TypeIndex::DateV2) {
return new AggregateFunctionGroupArrayIntersect<DateV2>(argument_types);
} else if (which.idx == TypeIndex::DateTimeV2) {
return new AggregateFunctionGroupArrayIntersect<DateTimeV2>(argument_types);
} else {
/// Check that we can use plain version of AggregateFunctionGroupArrayIntersectGeneric
if (nested_type->is_value_unambiguously_represented_in_contiguous_memory_region())
return new AggregateFunctionGroupArrayIntersectGeneric<true>(argument_types);
else
return new AggregateFunctionGroupArrayIntersectGeneric<false>(argument_types);
}
}
inline AggregateFunctionPtr create_aggregate_function_group_array_intersect_impl(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) {
const auto& nested_type = remove_nullable(
dynamic_cast<const DataTypeArray&>(*(argument_types[0])).get_nested_type());
AggregateFunctionPtr res = nullptr;
WhichDataType which(nested_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
res = creator_without_type::create<AggregateFunctionGroupArrayIntersect<TYPE>>( \
argument_types, result_is_nullable);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (!res) {
res = AggregateFunctionPtr(create_with_extra_types(nested_type, argument_types));
}
if (!res) {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"Illegal type {} of argument for aggregate function {}",
argument_types[0]->get_name(), name);
}
return res;
}
AggregateFunctionPtr create_aggregate_function_group_array_intersect(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) {
assert_unary(name, argument_types);
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);
if (!WhichDataType(argument_type).is_array())
throw Exception(ErrorCode::INVALID_ARGUMENT,
"Aggregate function groupArrayIntersect accepts only array type argument. "
"Provided argument type: " +
argument_type->get_name());
return create_aggregate_function_group_array_intersect_impl(name, {argument_type},
result_is_nullable);
}
void register_aggregate_function_group_array_intersect(AggregateFunctionSimpleFactory& factory) {
factory.register_function_both("group_array_intersect",
create_aggregate_function_group_array_intersect);
}
} // namespace doris::vectorized

View File

@ -0,0 +1,526 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionGroupArrayIntersect.cpp
// and modified by Doris
#include <cassert>
#include <memory>
#include "exprs/hybrid_set.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/factory_helpers.h"
#include "vec/aggregate_functions/helpers.h"
#include "vec/columns/column_array.h"
#include "vec/common/assert_cast.h"
#include "vec/core/field.h"
#include "vec/data_types/data_type_array.h"
#include "vec/data_types/data_type_number.h"
#include "vec/data_types/data_type_string.h"
#include "vec/data_types/data_type_time_v2.h"
#include "vec/io/io_helper.h"
#include "vec/io/var_int.h"
namespace doris::vectorized {
class Arena;
class BufferReadable;
class BufferWritable;
} // namespace doris::vectorized
namespace doris::vectorized {
/// Only for changing Numeric type or Date(DateTime)V2 type to PrimitiveType so that to inherit HybridSet
template <typename T>
constexpr PrimitiveType type_to_primitive_type() {
if constexpr (std::is_same_v<T, UInt8> || std::is_same_v<T, Int8>) {
return TYPE_TINYINT;
} else if constexpr (std::is_same_v<T, Int16>) {
return TYPE_SMALLINT;
} else if constexpr (std::is_same_v<T, Int32>) {
return TYPE_INT;
} else if constexpr (std::is_same_v<T, Int64>) {
return TYPE_BIGINT;
} else if constexpr (std::is_same_v<T, Int128>) {
return TYPE_LARGEINT;
} else if constexpr (std::is_same_v<T, Float32>) {
return TYPE_FLOAT;
} else if constexpr (std::is_same_v<T, Float64>) {
return TYPE_DOUBLE;
} else if constexpr (std::is_same_v<T, DateV2>) {
return TYPE_DATEV2;
} else if constexpr (std::is_same_v<T, DateTimeV2>) {
return TYPE_DATETIMEV2;
} else {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"Only for changing Numeric type or Date(DateTime)V2 type to PrimitiveType");
}
}
template <typename T>
class NullableNumericOrDateSet : public HybridSet<type_to_primitive_type<T>(),
DynamicContainer<typename PrimitiveTypeTraits<
type_to_primitive_type<T>()>::CppType>> {
public:
NullableNumericOrDateSet() { this->_null_aware = true; }
void change_contains_null_value(bool target_value) { this->_contains_null = target_value; }
};
template <typename T>
struct AggregateFunctionGroupArrayIntersectData {
using ColVecType = ColumnVector<T>;
using NullableNumericOrDateSetType = NullableNumericOrDateSet<T>;
using Set = std::unique_ptr<NullableNumericOrDateSetType>;
AggregateFunctionGroupArrayIntersectData()
: value(std::make_unique<NullableNumericOrDateSetType>()) {}
Set value;
bool init = false;
void process_col_data(auto& column_data, size_t offset, size_t arr_size, bool& init, Set& set) {
const bool is_column_data_nullable = column_data.is_nullable();
const ColumnNullable* col_null = nullptr;
const ColVecType* nested_column_data = nullptr;
if (is_column_data_nullable) {
auto* const_col_data = const_cast<IColumn*>(&column_data);
col_null = static_cast<ColumnNullable*>(const_col_data);
nested_column_data = &assert_cast<const ColVecType&>(col_null->get_nested_column());
} else {
nested_column_data = &static_cast<const ColVecType&>(column_data);
}
if (!init) {
for (size_t i = 0; i < arr_size; ++i) {
const bool is_null_element =
is_column_data_nullable && col_null->is_null_at(offset + i);
const T* src_data =
is_null_element ? nullptr : &(nested_column_data->get_element(offset + i));
set->insert(src_data);
}
init = true;
} else if (set->size() != 0 || set->contain_null()) {
Set new_set = std::make_unique<NullableNumericOrDateSetType>();
for (size_t i = 0; i < arr_size; ++i) {
const bool is_null_element =
is_column_data_nullable && col_null->is_null_at(offset + i);
const T* src_data =
is_null_element ? nullptr : &(nested_column_data->get_element(offset + i));
if (set->find(src_data) || (set->contain_null() && src_data == nullptr)) {
new_set->insert(src_data);
}
}
set = std::move(new_set);
}
}
};
/// Puts all values to the hybrid set. Returns an array of unique values. Implemented for numeric/date types.
template <typename T>
class AggregateFunctionGroupArrayIntersect
: public IAggregateFunctionDataHelper<AggregateFunctionGroupArrayIntersectData<T>,
AggregateFunctionGroupArrayIntersect<T>> {
private:
using State = AggregateFunctionGroupArrayIntersectData<T>;
DataTypePtr argument_type;
public:
AggregateFunctionGroupArrayIntersect(const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<AggregateFunctionGroupArrayIntersectData<T>,
AggregateFunctionGroupArrayIntersect<T>>(
argument_types_),
argument_type(argument_types_[0]) {}
AggregateFunctionGroupArrayIntersect(const DataTypes& argument_types_,
const bool result_is_nullable)
: IAggregateFunctionDataHelper<AggregateFunctionGroupArrayIntersectData<T>,
AggregateFunctionGroupArrayIntersect<T>>(
argument_types_),
argument_type(argument_types_[0]) {}
String get_name() const override { return "group_array_intersect"; }
DataTypePtr get_return_type() const override { return argument_type; }
bool allocates_memory_in_arena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
auto& data = this->data(place);
auto& init = data.init;
auto& set = data.value;
const bool col_is_nullable = (*columns[0]).is_nullable();
const ColumnArray& column =
col_is_nullable ? assert_cast<const ColumnArray&>(
assert_cast<const ColumnNullable&>(*columns[0])
.get_nested_column())
: assert_cast<const ColumnArray&>(*columns[0]);
const auto& offsets = column.get_offsets();
const auto offset = offsets[row_num - 1];
const auto arr_size = offsets[row_num] - offset;
const auto& column_data = column.get_data();
data.process_col_data(column_data, offset, arr_size, init, set);
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena*) const override {
auto& data = this->data(place);
auto& set = data.value;
auto& rhs_set = this->data(rhs).value;
if (!this->data(rhs).init) {
return;
}
auto& init = data.init;
if (!init) {
set->change_contains_null_value(rhs_set->contain_null());
HybridSetBase::IteratorBase* it = rhs_set->begin();
while (it->has_next()) {
const void* value = it->get_value();
set->insert(value);
it->next();
}
init = true;
return;
}
if (set->size() != 0) {
auto create_new_set = [](auto& lhs_val, auto& rhs_val) {
typename State::Set new_set =
std::make_unique<typename State::NullableNumericOrDateSetType>();
HybridSetBase::IteratorBase* it = lhs_val->begin();
while (it->has_next()) {
const void* value = it->get_value();
if ((rhs_val->find(value))) {
new_set->insert(value);
}
it->next();
}
new_set->change_contains_null_value(lhs_val->contain_null() &&
rhs_val->contain_null());
return new_set;
};
auto new_set = rhs_set->size() < set->size() ? create_new_set(rhs_set, set)
: create_new_set(set, rhs_set);
set = std::move(new_set);
}
}
void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
auto& data = this->data(place);
auto& set = data.value;
auto& init = data.init;
const bool is_set_contains_null = set->contain_null();
write_pod_binary(is_set_contains_null, buf);
write_pod_binary(init, buf);
write_var_uint(set->size(), buf);
HybridSetBase::IteratorBase* it = set->begin();
while (it->has_next()) {
const T* value_ptr = static_cast<const T*>(it->get_value());
write_int_binary((*value_ptr), buf);
it->next();
}
}
void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena*) const override {
auto& data = this->data(place);
bool is_set_contains_null;
read_pod_binary(is_set_contains_null, buf);
data.value->change_contains_null_value(is_set_contains_null);
read_pod_binary(data.init, buf);
size_t size;
read_var_uint(size, buf);
T element;
for (size_t i = 0; i < size; ++i) {
read_int_binary(element, buf);
data.value->insert(static_cast<void*>(&element));
}
}
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
ColumnArray& arr_to = assert_cast<ColumnArray&>(to);
ColumnArray::Offsets64& offsets_to = arr_to.get_offsets();
auto& to_nested_col = arr_to.get_data();
const bool is_nullable = to_nested_col.is_nullable();
auto insert_values = [](typename State::ColVecType& nested_col, auto& set,
bool is_nullable = false, ColumnNullable* col_null = nullptr) {
size_t old_size = nested_col.get_data().size();
size_t res_size = set->size();
size_t i = 0;
if (is_nullable && set->contain_null()) {
col_null->insert_data(nullptr, 0);
res_size += 1;
i = 1;
}
nested_col.get_data().resize(old_size + res_size);
HybridSetBase::IteratorBase* it = set->begin();
while (it->has_next()) {
const auto value = *reinterpret_cast<const T*>(it->get_value());
nested_col.get_data()[old_size + i] = value;
if (is_nullable) {
col_null->get_null_map_data().push_back(0);
}
it->next();
++i;
}
};
const auto& set = this->data(place).value;
if (is_nullable) {
auto col_null = reinterpret_cast<ColumnNullable*>(&to_nested_col);
auto& nested_col =
assert_cast<typename State::ColVecType&>(col_null->get_nested_column());
offsets_to.push_back(offsets_to.back() + set->size() + (set->contain_null() ? 1 : 0));
insert_values(nested_col, set, true, col_null);
} else {
auto& nested_col = static_cast<typename State::ColVecType&>(to_nested_col);
offsets_to.push_back(offsets_to.back() + set->size());
insert_values(nested_col, set);
}
}
};
/// Generic implementation, it uses serialized representation as object descriptor.
class NullableStringSet : public StringValueSet<DynamicContainer<StringRef>> {
public:
NullableStringSet() { this->_null_aware = true; }
void change_contains_null_value(bool target_value) { this->_contains_null = target_value; }
};
struct AggregateFunctionGroupArrayIntersectGenericData {
using Set = std::unique_ptr<NullableStringSet>;
AggregateFunctionGroupArrayIntersectGenericData()
: value(std::make_unique<NullableStringSet>()) {}
Set value;
bool init = false;
};
/** Template parameter with true value should be used for columns that store their elements in memory continuously.
* For such columns group_array_intersect() can be implemented more efficiently (especially for small numeric arrays).
*/
template <bool is_plain_column = false>
class AggregateFunctionGroupArrayIntersectGeneric
: public IAggregateFunctionDataHelper<
AggregateFunctionGroupArrayIntersectGenericData,
AggregateFunctionGroupArrayIntersectGeneric<is_plain_column>> {
private:
using State = AggregateFunctionGroupArrayIntersectGenericData;
DataTypePtr input_data_type;
public:
AggregateFunctionGroupArrayIntersectGeneric(const DataTypes& input_data_type_)
: IAggregateFunctionDataHelper<
AggregateFunctionGroupArrayIntersectGenericData,
AggregateFunctionGroupArrayIntersectGeneric<is_plain_column>>(
input_data_type_),
input_data_type(input_data_type_[0]) {}
String get_name() const override { return "group_array_intersect"; }
DataTypePtr get_return_type() const override { return input_data_type; }
bool allocates_memory_in_arena() const override { return true; }
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena* arena) const override {
auto& data = this->data(place);
auto& init = data.init;
auto& set = data.value;
const bool col_is_nullable = (*columns[0]).is_nullable();
const ColumnArray& column =
col_is_nullable ? assert_cast<const ColumnArray&>(
assert_cast<const ColumnNullable&>(*columns[0])
.get_nested_column())
: assert_cast<const ColumnArray&>(*columns[0]);
const auto nested_column_data = column.get_data_ptr();
const auto& offsets = column.get_offsets();
const auto offset = offsets[row_num - 1];
const auto arr_size = offsets[row_num] - offset;
const auto& column_data = column.get_data();
const bool is_column_data_nullable = column_data.is_nullable();
ColumnNullable* col_null = nullptr;
if (is_column_data_nullable) {
auto const_col_data = const_cast<IColumn*>(&column_data);
col_null = static_cast<ColumnNullable*>(const_col_data);
}
auto process_element = [&](size_t i) {
const bool is_null_element =
is_column_data_nullable && col_null->is_null_at(offset + i);
StringRef src = StringRef();
if constexpr (is_plain_column) {
src = nested_column_data->get_data_at(offset + i);
} else {
const char* begin = nullptr;
src = nested_column_data->serialize_value_into_arena(offset + i, *arena, begin);
}
src.data = is_null_element ? nullptr : arena->insert(src.data, src.size);
return src;
};
if (!init) {
for (size_t i = 0; i < arr_size; ++i) {
StringRef src = process_element(i);
set->insert((void*)src.data, src.size);
}
init = true;
} else if (set->size() != 0 || set->contain_null()) {
typename State::Set new_set = std::make_unique<NullableStringSet>();
for (size_t i = 0; i < arr_size; ++i) {
StringRef src = process_element(i);
if (set->find(src.data, src.size) || (set->contain_null() && src.data == nullptr)) {
new_set->insert((void*)src.data, src.size);
}
}
set = std::move(new_set);
}
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena*) const override {
auto& data = this->data(place);
auto& set = data.value;
auto& rhs_set = this->data(rhs).value;
if (!this->data(rhs).init) {
return;
}
auto& init = data.init;
if (!init) {
set->change_contains_null_value(rhs_set->contain_null());
HybridSetBase::IteratorBase* it = rhs_set->begin();
while (it->has_next()) {
const auto* value = reinterpret_cast<const StringRef*>(it->get_value());
set->insert((void*)(value->data), value->size);
it->next();
}
init = true;
} else if (set->size() != 0) {
auto create_new_set = [](auto& lhs_val, auto& rhs_val) {
typename State::Set new_set = std::make_unique<NullableStringSet>();
HybridSetBase::IteratorBase* it = lhs_val->begin();
while (it->has_next()) {
const auto* value = reinterpret_cast<const StringRef*>(it->get_value());
if (rhs_val->find(value)) {
new_set->insert((void*)value->data, value->size);
}
it->next();
}
new_set->change_contains_null_value(lhs_val->contain_null() &&
rhs_val->contain_null());
return new_set;
};
auto new_set = rhs_set->size() < set->size() ? create_new_set(rhs_set, set)
: create_new_set(set, rhs_set);
set = std::move(new_set);
}
}
void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
auto& data = this->data(place);
auto& set = data.value;
auto& init = data.init;
const bool is_set_contains_null = set->contain_null();
write_pod_binary(is_set_contains_null, buf);
write_pod_binary(init, buf);
write_var_uint(set->size(), buf);
HybridSetBase::IteratorBase* it = set->begin();
while (it->has_next()) {
const auto* value = reinterpret_cast<const StringRef*>(it->get_value());
write_string_binary(*value, buf);
it->next();
}
}
void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena* arena) const override {
auto& data = this->data(place);
bool is_set_contains_null;
read_pod_binary(is_set_contains_null, buf);
data.value->change_contains_null_value(is_set_contains_null);
read_pod_binary(data.init, buf);
size_t size;
read_var_uint(size, buf);
StringRef element;
for (size_t i = 0; i < size; ++i) {
element = read_string_binary_into(*arena, buf);
data.value->insert((void*)element.data, element.size);
}
}
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
auto& arr_to = assert_cast<ColumnArray&>(to);
ColumnArray::Offsets64& offsets_to = arr_to.get_offsets();
auto& data_to = arr_to.get_data();
auto col_null = reinterpret_cast<ColumnNullable*>(&data_to);
const auto& set = this->data(place).value;
auto res_size = set->size();
if (set->contain_null()) {
col_null->insert_data(nullptr, 0);
res_size += 1;
}
offsets_to.push_back(offsets_to.back() + res_size);
HybridSetBase::IteratorBase* it = set->begin();
while (it->has_next()) {
const auto* value = reinterpret_cast<const StringRef*>(it->get_value());
if constexpr (is_plain_column) {
data_to.insert_data(value->data, value->size);
} else {
std::ignore = data_to.deserialize_and_insert_from_arena(value->data);
}
it->next();
}
}
};
} // namespace doris::vectorized

View File

@ -50,6 +50,7 @@ void register_aggregate_function_stddev_variance_pop(AggregateFunctionSimpleFact
void register_aggregate_function_stddev_variance_samp(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_topn(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_approx_count_distinct(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_group_array_intersect(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_group_concat(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_window_funnel(AggregateFunctionSimpleFactory& factory);
@ -81,6 +82,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
register_aggregate_function_uniq(instance);
register_aggregate_function_bit(instance);
register_aggregate_function_bitmap(instance);
register_aggregate_function_group_array_intersect(instance);
register_aggregate_function_group_concat(instance);
register_aggregate_function_quantile_state(instance);
register_aggregate_function_combinator_distinct(instance);

View File

@ -55,7 +55,8 @@ public class AggregateFunction extends Function {
FunctionSet.COUNT, "approx_count_distinct", "ndv", FunctionSet.BITMAP_UNION_INT,
FunctionSet.BITMAP_UNION_COUNT, "ndv_no_finalize", FunctionSet.WINDOW_FUNNEL, FunctionSet.RETENTION,
FunctionSet.SEQUENCE_MATCH, FunctionSet.SEQUENCE_COUNT, FunctionSet.MAP_AGG, FunctionSet.BITMAP_AGG,
FunctionSet.ARRAY_AGG, FunctionSet.COLLECT_LIST, FunctionSet.COLLECT_SET);
FunctionSet.ARRAY_AGG, FunctionSet.COLLECT_LIST, FunctionSet.COLLECT_SET,
FunctionSet.GROUP_ARRAY_INTERSECT);
public static ImmutableSet<String> ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET =
ImmutableSet.of("stddev_samp", "variance_samp", "var_samp", "percentile_approx", "first_value",

View File

@ -33,6 +33,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.CountByEnum;
import org.apache.doris.nereids.trees.expressions.functions.agg.Covar;
import org.apache.doris.nereids.trees.expressions.functions.agg.CovarSamp;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupArrayIntersect;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitOr;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitXor;
@ -103,6 +104,7 @@ public class BuiltinAggregateFunctions implements FunctionHelper {
agg(CountByEnum.class, "count_by_enum"),
agg(Covar.class, "covar", "covar_pop"),
agg(CovarSamp.class, "covar_samp"),
agg(GroupArrayIntersect.class, "group_array_intersect"),
agg(GroupBitAnd.class, "group_bit_and"),
agg(GroupBitOr.class, "group_bit_or"),
agg(GroupBitXor.class, "group_bit_xor"),

View File

@ -612,6 +612,8 @@ public class FunctionSet<T> {
public static final String GROUP_ARRAY = "group_array";
public static final String GROUP_ARRAY_INTERSECT = "group_array_intersect";
public static final String ARRAY_AGG = "array_agg";
// Populate all the aggregate builtins in the catalog.
@ -1503,7 +1505,9 @@ public class FunctionSet<T> {
addBuiltin(
AggregateFunction.createBuiltin(GROUP_ARRAY, Lists.newArrayList(t, Type.INT), new ArrayType(t),
t, "", "", "", "", "", true, false, true, true));
addBuiltin(
AggregateFunction.createBuiltin(GROUP_ARRAY_INTERSECT, Lists.newArrayList(new ArrayType(t)),
new ArrayType(t), t, "", "", "", "", "", true, false, true, true));
addBuiltin(AggregateFunction.createBuiltin(ARRAY_AGG, Lists.newArrayList(t), new ArrayType(t), t, "", "", "", "", "",
true, false, true, true));

View File

@ -0,0 +1,76 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.coercion.AnyDataType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
/**
* AggregateFunction 'group_array_intersect'.
*/
public class GroupArrayIntersect extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0))));
/**
* constructor with 1 argument.
*/
public GroupArrayIntersect(Expression arg) {
super("group_array_intersect", arg);
}
/**
* constructor with 1 argument.
*/
public GroupArrayIntersect(boolean distinct, Expression arg) {
super("group_array_intersect", false, arg);
}
/**
* withChildren.
*/
@Override
public AggregateFunction withDistinctAndChildren(boolean distinct, List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new GroupArrayIntersect(distinct, children.get(0));
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitGroupArrayIntersect(this, context);
}
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}
}

View File

@ -34,6 +34,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.CountByEnum;
import org.apache.doris.nereids.trees.expressions.functions.agg.Covar;
import org.apache.doris.nereids.trees.expressions.functions.agg.CovarSamp;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupArrayIntersect;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitOr;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitXor;
@ -168,6 +169,10 @@ public interface AggregateFunctionVisitor<R, C> {
return visitAggregateFunction(multiDistinctSum0, context);
}
default R visitGroupArrayIntersect(GroupArrayIntersect groupArrayIntersect, C context) {
return visitAggregateFunction(groupArrayIntersect, context);
}
default R visitGroupBitAnd(GroupBitAnd groupBitAnd, C context) {
return visitNullableAggregateFunction(groupBitAnd, context);
}

View File

@ -1410,6 +1410,148 @@
-- !sql_count_AnyData_agg_phase_4_notnull --
12 12
-- !sql_group_array_intersect_array_bool --
[0]
-- !sql_group_array_intersect_array_tinyint --
[1]
-- !sql_group_array_intersect_array_smallint --
[]
-- !sql_group_array_intersect_array_int --
[]
[1]
[2]
[3]
[4]
[5]
[6]
[7]
[8]
[9]
[10]
[11]
[12]
-- !sql_group_array_intersect_array_bigint --
[]
-- !sql_group_array_intersect_array_largeint --
[8]
-- !sql_group_array_intersect_array_float --
[5]
-- !sql_group_array_intersect_array_double --
[0.2]
-- !sql_group_array_intersect_array_date --
["2012-03-03"]
-- !sql_group_array_intersect_array_datetime --
["2012-03-04 04:03:04"]
-- !sql_group_array_intersect_array_datev2 --
["2012-03-06"]
-- !sql_group_array_intersect_array_datetimev2 --
["2012-03-09 09:08:09.000000"]
-- !sql_group_array_intersect_array_char --
["char21", "char11", "char31"]
-- !sql_group_array_intersect_array_varchar --
["varchar11", "char11", "varchar31", "char31", "varchar21", "char21"]
-- !sql_group_array_intersect_array_string --
["varchar11", "string1", "varchar31", "char31", "varchar21", "char21"]
-- !sql_group_array_intersect_array_decimal --
[]
[0.100000000]
[0.200000000]
[0.300000000]
[0.400000000]
[0.500000000]
[0.600000000]
[0.700000000]
[0.800000000]
[0.900000000]
[1.000000000]
[1.100000000]
[1.200000000]
-- !sql_group_array_intersect_array_bool_notnull --
[0]
-- !sql_group_array_intersect_array_tinyint_notnull --
[1]
-- !sql_group_array_intersect_array_smallint_notnull --
[]
-- !sql_group_array_intersect_array_int_notnull --
[1]
[2]
[3]
[4]
[5]
[6]
[7]
[8]
[9]
[10]
[11]
[12]
-- !sql_group_array_intersect_array_bigint_notnull --
[]
-- !sql_group_array_intersect_array_largeint_notnull --
[8]
-- !sql_group_array_intersect_array_float_notnull --
[5]
-- !sql_group_array_intersect_array_double_notnull --
[0.2]
-- !sql_group_array_intersect_array_date_notnull --
["2012-03-03"]
-- !sql_group_array_intersect_array_datetime_notnull --
["2012-03-04 04:03:04"]
-- !sql_group_array_intersect_array_datev2_notnull --
["2012-03-06"]
-- !sql_group_array_intersect_array_datetimev2_notnull --
["2012-03-09 09:08:09.000000"]
-- !sql_group_array_intersect_array_char_notnull --
["char21", "char11", "char31"]
-- !sql_group_array_intersect_array_varchar_notnull --
["varchar11", "char11", "varchar31", "char31", "varchar21", "char21"]
-- !sql_group_array_intersect_array_string_notnull --
["varchar11", "string1", "varchar31", "char31", "varchar21", "char21"]
-- !sql_group_array_intersect_array_decimal_notnull --
[0.100000000]
[0.200000000]
[0.300000000]
[0.400000000]
[0.500000000]
[0.600000000]
[0.700000000]
[0.800000000]
[0.900000000]
[1.000000000]
[1.100000000]
[1.200000000]
-- !sql_group_bit_and_TinyInt_gb --
\N
0

View File

@ -0,0 +1,93 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !int_1 --
[null]
-- !int_2 --
[13, 12]
-- !int_3 --
[]
-- !int_4 --
[null]
-- !int_5 --
[7, 6]
-- !int_6 --
[null]
-- !int_7 --
[null, 13, 12]
-- !int_8 --
[]
-- !int_9 --
[]
-- !float_1 --
[6.3, 7.3]
-- !float_2 --
[7.3]
-- !float_3 --
[7.3]
-- !datetimev2_1 --
[]
-- !datetimev2_2 --
["2024-03-24 00:00:00.000"]
-- !datev2_1 --
["2024-03-29"]
-- !datev2_2 --
["2024-05-23"]
-- !string_1 --
[]
-- !string_2 --
["a"]
-- !bigint --
[1234567890123456]
-- !decimal --
[1.34000]
-- !groupby_1 --
0 [0]
1 [4, 1, 5, 2, 3]
-- !groupby_2 --
18 ["c", "e", "b", "d", "a", "f"]
19 ["c", "ff", "cc", "bb", "f", "aa", "dd", "b", "d", "a"]
20 [null, "a"]
21 [null]
22 ["x", "y"]
-- !groupby_3 --
18 ["c", "e", "b", "d", "a", "f"]
-- !notnull_1 --
[]
-- !notnull_2 --
[] []
-- !notnull_3 --
[7.7, 6.6]
-- !notnull_4 --
["c", "b", "d", "a", "f"]
-- !notnull_5 --
[] [] []
-- !notnull_6 --
[]

View File

@ -589,6 +589,72 @@ suite("nereids_agg_fn") {
select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT, TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/ count(distinct id, kint), count(kint) from fn_test group by kbool order by kbool'''
qt_sql_count_AnyData_agg_phase_4_notnull '''
select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT, TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/ count(distinct id), count(kint) from fn_test'''
qt_sql_group_array_intersect_array_bool '''
select group_array_intersect(kabool) from fn_test where id<= 2;'''
qt_sql_group_array_intersect_array_tinyint '''
select group_array_intersect(katint) from fn_test where id between 7 and 10;'''
qt_sql_group_array_intersect_array_smallint '''
select group_array_intersect(kasint) from fn_test;'''
qt_sql_group_array_intersect_array_int '''
select group_array_intersect(kaint) from fn_test group by id order by id;'''
qt_sql_group_array_intersect_array_bigint '''
select group_array_intersect(kabint) from fn_test where id between 7 and 10;'''
qt_sql_group_array_intersect_array_largeint '''
select group_array_intersect(kalint) from fn_test where id = 7;'''
qt_sql_group_array_intersect_array_float '''
select group_array_intersect(kafloat) from fn_test where id = 4;'''
qt_sql_group_array_intersect_array_double '''
select group_array_intersect(kadbl) from fn_test where id = 1;'''
qt_sql_group_array_intersect_array_date '''
select group_array_intersect(kadt) from fn_test where id = 2;'''
qt_sql_group_array_intersect_array_datetime '''
select group_array_intersect(kadtm) from fn_test where id = 3;'''
qt_sql_group_array_intersect_array_datev2 '''
select group_array_intersect(kadtv2) from fn_test where id = 5;'''
qt_sql_group_array_intersect_array_datetimev2 '''
select group_array_intersect(kadtmv2) from fn_test where id = 8;'''
qt_sql_group_array_intersect_array_char '''
select group_array_intersect(kachr) from fn_test where id in (0, 3);'''
qt_sql_group_array_intersect_array_varchar '''
select group_array_intersect(kavchr) from fn_test where id = 6;'''
qt_sql_group_array_intersect_array_string '''
select group_array_intersect(kastr) from fn_test where id in (6, 9);'''
qt_sql_group_array_intersect_array_decimal '''
select group_array_intersect(kadcml) from fn_test group by id order by id;'''
qt_sql_group_array_intersect_array_bool_notnull '''
select group_array_intersect(kabool) from fn_test_not_nullable where id<= 2;'''
qt_sql_group_array_intersect_array_tinyint_notnull '''
select group_array_intersect(katint) from fn_test_not_nullable where id between 7 and 10;'''
qt_sql_group_array_intersect_array_smallint_notnull '''
select group_array_intersect(kasint) from fn_test_not_nullable;'''
qt_sql_group_array_intersect_array_int_notnull '''
select group_array_intersect(kaint) from fn_test_not_nullable group by id order by id;'''
qt_sql_group_array_intersect_array_bigint_notnull '''
select group_array_intersect(kabint) from fn_test_not_nullable where id between 7 and 10;'''
qt_sql_group_array_intersect_array_largeint_notnull '''
select group_array_intersect(kalint) from fn_test_not_nullable where id = 7;'''
qt_sql_group_array_intersect_array_float_notnull '''
select group_array_intersect(kafloat) from fn_test_not_nullable where id = 4;'''
qt_sql_group_array_intersect_array_double_notnull '''
select group_array_intersect(kadbl) from fn_test_not_nullable where id = 1;'''
qt_sql_group_array_intersect_array_date_notnull '''
select group_array_intersect(kadt) from fn_test_not_nullable where id = 2;'''
qt_sql_group_array_intersect_array_datetime_notnull '''
select group_array_intersect(kadtm) from fn_test_not_nullable where id = 3;'''
qt_sql_group_array_intersect_array_datev2_notnull '''
select group_array_intersect(kadtv2) from fn_test_not_nullable where id = 5;'''
qt_sql_group_array_intersect_array_datetimev2_notnull '''
select group_array_intersect(kadtmv2) from fn_test_not_nullable where id = 8;'''
qt_sql_group_array_intersect_array_char_notnull '''
select group_array_intersect(kachr) from fn_test_not_nullable where id in (0, 3);'''
qt_sql_group_array_intersect_array_varchar_notnull '''
select group_array_intersect(kavchr) from fn_test_not_nullable where id = 6;'''
qt_sql_group_array_intersect_array_string_notnull '''
select group_array_intersect(kastr) from fn_test_not_nullable where id in (6, 9);'''
qt_sql_group_array_intersect_array_decimal_notnull '''
select group_array_intersect(kadcml) from fn_test_not_nullable group by id order by id;'''
qt_sql_group_bit_and_TinyInt_gb '''
select group_bit_and(ktint) from fn_test group by kbool order by kbool'''
qt_sql_group_bit_and_TinyInt '''

View File

@ -0,0 +1,106 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
suite("group_array_intersect") {
sql "DROP TABLE IF EXISTS `group_array_intersect_test`;"
sql """
CREATE TABLE `group_array_intersect_test` (
`id` int(11) NULL COMMENT ""
, `c_array_int` ARRAY<int(11)> NULL COMMENT ""
, `c_array_datetimev2` ARRAY<DATETIMEV2(3)> NULL COMMENT ""
, `c_array_float` ARRAY<float> NULL COMMENT ""
, `c_array_datev2` ARRAY<DATEV2> NULL COMMENT ""
, `c_array_string` ARRAY<string> NULL COMMENT ""
, `c_array_bigint` ARRAY<bigint> NULL COMMENT ""
, `c_array_decimal` ARRAY<decimal(10, 5)> NULL COMMENT ""
) ENGINE=OLAP
DUPLICATE KEY(`id`)
COMMENT "OLAP"
DISTRIBUTED BY HASH(`id`) BUCKETS 1
PROPERTIES (
"replication_allocation" = "tag.location.default: 1",
"in_memory" = "false",
"storage_format" = "V2"
);
"""
sql """INSERT INTO `group_array_intersect_test`(id, c_array_int) VALUES (0, [0]),(1, [1,2,3,4,5]), (2, [6,7,8]), (3, []), (4, null), (5, [6, 7]), (6, [NULL]);"""
sql """INSERT INTO `group_array_intersect_test`(id, c_array_int) VALUES (12, [12, null, 13]), (13, [null, null]), (14, [12, 13]);"""
sql """INSERT INTO `group_array_intersect_test`(id, c_array_float) VALUES (7, [6.3, 7.3]), (8, [7.3, 8.3]), (9, [7.3, 9.3, 8.3]);"""
sql """INSERT INTO `group_array_intersect_test`(id, c_array_datetimev2) VALUES (10, ['2024-03-23 00:00:00', '2024-03-24 00:00:00']), (11, ['2024-03-24 00:00:00', '2024-03-25 00:00:00']);"""
sql """INSERT INTO `group_array_intersect_test`(id, c_array_datev2) VALUES (15, ['2024-05-23', '2024-03-29']), (16, ['2024-03-29', '2024-03-25']), (17, ['2024-05-23', null]);"""
sql """INSERT INTO `group_array_intersect_test`(id, c_array_string) VALUES (18, ['a', 'b', 'c', 'd', 'e', 'f']), (19, ['a', 'aa', 'b', 'bb', 'c', 'cc', 'd', 'dd', 'f', 'ff']);"""
sql """INSERT INTO `group_array_intersect_test`(id, c_array_string) VALUES (20, ['a', null]), (21, [null, null]), (22, ['x', 'y']);"""
sql """INSERT INTO `group_array_intersect_test`(id, c_array_bigint) VALUES (23, [1234567890123456]), (24, [1234567890123456, 2333333333333333]);"""
sql """INSERT INTO `group_array_intersect_test`(id, c_array_decimal) VALUES (25, [1.34,2.00188888888888888]), (26, [1.34,2.00123344444455555]);"""
qt_int_1 """select group_array_intersect(c_array_int) from group_array_intersect_test where id in (6, 12);"""
qt_int_2 """select group_array_intersect(c_array_int) from group_array_intersect_test where id in (14, 12);"""
qt_int_3 """select group_array_intersect(c_array_int) from group_array_intersect_test where id in (0, 6);"""
qt_int_4 """select group_array_intersect(c_array_int) from group_array_intersect_test where id in (13);"""
qt_int_5 """select group_array_intersect(c_array_int) from group_array_intersect_test where id in (2, 5);"""
qt_int_6 """select group_array_intersect(c_array_int) from group_array_intersect_test where id in (6, 13);"""
qt_int_7 """select group_array_intersect(c_array_int) from group_array_intersect_test where id in (12);"""
qt_int_8 """select group_array_intersect(c_array_int) from group_array_intersect_test where id in (6, 7);"""
qt_int_9 """select group_array_intersect(c_array_int) from group_array_intersect_test where id in (9, 12);"""
qt_float_1 """select group_array_intersect(c_array_float) from group_array_intersect_test where id = 7;"""
qt_float_2 """select group_array_intersect(c_array_float) from group_array_intersect_test where id between 7 and 8;"""
qt_float_3 """select group_array_intersect(c_array_float) from group_array_intersect_test where id in (7, 9);"""
qt_datetimev2_1 """select group_array_intersect(c_array_datetimev2) from group_array_intersect_test;"""
qt_datetimev2_2 """select group_array_intersect(c_array_datetimev2) from group_array_intersect_test where id in (10, 11);"""
qt_datev2_1 """select group_array_intersect(c_array_datev2) from group_array_intersect_test where id in (15, 16);"""
qt_datev2_2 """select group_array_intersect(c_array_datev2) from group_array_intersect_test where id in (15, 17);"""
qt_string_1 """select group_array_intersect(c_array_string) from group_array_intersect_test where id in (17, 20);"""
qt_string_2 """select group_array_intersect(c_array_string) from group_array_intersect_test where id in (18, 20);"""
qt_bigint """select group_array_intersect(c_array_bigint) from group_array_intersect_test where id in (23, 24);"""
qt_decimal """select group_array_intersect(c_array_decimal) from group_array_intersect_test where id in (25, 26);"""
qt_groupby_1 """select id, group_array_intersect(c_array_int) from group_array_intersect_test where id <= 1 group by id order by id;"""
qt_groupby_2 """select id, group_array_intersect(c_array_string) from group_array_intersect_test where c_array_string is not null group by id order by id;"""
qt_groupby_3 """select id, group_array_intersect(c_array_string) from group_array_intersect_test where id = 18 group by id order by id;"""
sql "DROP TABLE IF EXISTS `group_array_intersect_test_not_null`;"
sql """
CREATE TABLE `group_array_intersect_test_not_null` (
`id` int(11) NULL COMMENT ""
, `c_array_int` ARRAY<int(11)> NOT NULL COMMENT ""
, `c_array_float` ARRAY<float> NOT NULL COMMENT ""
, `c_array_string` ARRAY<string> NOT NULL COMMENT ""
) ENGINE=OLAP
DUPLICATE KEY(`id`)
COMMENT "OLAP"
DISTRIBUTED BY HASH(`id`) BUCKETS 1
PROPERTIES (
"replication_allocation" = "tag.location.default: 1",
"in_memory" = "false",
"storage_format" = "V2"
);
"""
sql """INSERT INTO `group_array_intersect_test_not_null`(id, c_array_int, c_array_float, c_array_string) VALUES (1, [1, 2, 3, 4, 5], [1.1, 2.2, 3.3, 4.4, 5.5], ['a', 'b', 'c', 'd', 'e', 'f']);"""
sql """INSERT INTO `group_array_intersect_test_not_null`(id, c_array_int, c_array_float, c_array_string) VALUES (2, [6, 7, 8], [6.6, 7.7, 8.8], ['a', 'aa', 'b', 'bb', 'c', 'cc', 'd', 'dd', 'f', 'ff'])"""
sql """INSERT INTO `group_array_intersect_test_not_null`(id, c_array_int, c_array_float, c_array_string) VALUES (3, [], [], []);"""
sql """INSERT INTO `group_array_intersect_test_not_null`(id, c_array_int, c_array_float, c_array_string) VALUES (4, [6, 7], [6.6, 7.7], ['a']);"""
sql """INSERT INTO `group_array_intersect_test_not_null`(id, c_array_int, c_array_float, c_array_string) VALUES (5, [null], [null], ['x', 'y']);"""
qt_notnull_1 """select group_array_intersect(c_array_float) from group_array_intersect_test_not_null where array_size(c_array_float) between 1 and 2;"""
qt_notnull_2 """select group_array_intersect(c_array_int), group_array_intersect(c_array_float) from group_array_intersect_test_not_null where id between 2 and 3;"""
qt_notnull_3 """select group_array_intersect(c_array_float) from group_array_intersect_test_not_null where array_size(c_array_float) between 2 and 3;"""
qt_notnull_4 """select group_array_intersect(c_array_string) from group_array_intersect_test_not_null where id between 1 and 2;"""
qt_notnull_5 """select group_array_intersect(c_array_int), group_array_intersect(c_array_float), group_array_intersect(c_array_string) from group_array_intersect_test_not_null where id between 3 and 4;"""
qt_notnull_6 """select group_array_intersect(c_array_string) from group_array_intersect_test_not_null where id between 1 and 5;"""
}