[feature](function) Support for aggregate function foreach combiner for some error function (#31913)

Support for aggregate function foreach combiner for some error function
This commit is contained in:
Mryange
2024-03-20 18:13:27 +08:00
committed by yiguolei
parent b6a35d68b0
commit b92a764665
9 changed files with 388 additions and 26 deletions

View File

@ -21,6 +21,7 @@
#include <glog/logging.h>
#include <string.h>
#include <cstddef>
#include <limits>
#include <memory>
#include <new>
@ -300,6 +301,7 @@ template <typename T>
struct AggregateFunctionArrayAggData {
using ElementType = T;
using ColVecType = ColumnVectorOrDecimal<ElementType>;
using Self = AggregateFunctionArrayAggData<T>;
MutableColumnPtr column_data;
ColVecType* nested_column = nullptr;
NullMap* null_map = nullptr;
@ -362,12 +364,56 @@ struct AggregateFunctionArrayAggData {
}
to_arr.get_offsets().push_back(to_nested_col.size());
}
// Serializes this state into `buf`: the element count, then every null flag,
// then every nested value. read() consumes the same three sections in order.
void write(BufferWritable& buf) const {
    const size_t element_count = null_map->size();
    write_binary(element_count, buf);

    // Section 1: one null flag per element.
    const auto* null_flags = null_map->data();
    for (size_t idx = 0; idx != element_count; ++idx) {
        write_binary(null_flags[idx], buf);
    }

    // Section 2: one value per element, in the same order as the flags.
    const auto& values = nested_column->get_data();
    for (size_t idx = 0; idx != element_count; ++idx) {
        write_binary(values[idx], buf);
    }
}
// Restores a state produced by write(): element count, null flags, values.
// The DCHECKs require that the null map exists and is still empty.
void read(BufferReadable& buf) {
    DCHECK(null_map);
    DCHECK(null_map->empty());

    size_t element_count = 0;
    read_binary(element_count, buf);
    null_map->resize(element_count);
    nested_column->reserve(element_count);

    // Null flags are read straight into the freshly resized null map.
    auto* null_flags = null_map->data();
    for (size_t idx = 0; idx != element_count; ++idx) {
        read_binary(null_flags[idx], buf);
    }

    // Values are appended one by one to the nested column.
    auto& values = nested_column->get_data();
    for (size_t remaining = element_count; remaining > 0; --remaining) {
        ElementType value;
        read_binary(value, buf);
        values.push_back(value);
    }
}
void merge(const Self& rhs) {
const auto size = rhs.null_map->size();
null_map->resize(size);
nested_column->reserve(size);
for (size_t i = 0; i < size; i++) {
const auto null_value = rhs.null_map->data()[i];
const auto data_value = rhs.nested_column->get_data()[i];
null_map->data()[i] = null_value;
nested_column->get_data().push_back(data_value);
}
}
};
template <>
struct AggregateFunctionArrayAggData<StringRef> {
using ElementType = StringRef;
using ColVecType = ColumnString;
using Self = AggregateFunctionArrayAggData<StringRef>;
MutableColumnPtr column_data;
ColVecType* nested_column = nullptr;
NullMap* null_map = nullptr;
@ -417,6 +463,46 @@ struct AggregateFunctionArrayAggData<StringRef> {
}
to_arr.get_offsets().push_back(to_nested_col.size());
}
// Serializes this string state into `buf`: the element count, then every null
// flag, then every string. read() consumes the same three sections in order.
void write(BufferWritable& buf) const {
    const size_t element_count = null_map->size();
    write_binary(element_count, buf);

    // Section 1: one null flag per element.
    const auto* null_flags = null_map->data();
    for (size_t idx = 0; idx != element_count; ++idx) {
        write_binary(null_flags[idx], buf);
    }

    // Section 2: one string per element, in the same order as the flags.
    for (size_t idx = 0; idx != element_count; ++idx) {
        const auto str = nested_column->get_data_at(idx);
        write_string_binary(str, buf);
    }
}
// Restores a string state produced by write(): element count, null flags,
// strings. The DCHECKs require that the null map exists and is still empty.
void read(BufferReadable& buf) {
    DCHECK(null_map);
    DCHECK(null_map->empty());

    size_t element_count = 0;
    read_binary(element_count, buf);
    null_map->resize(element_count);
    nested_column->reserve(element_count);

    // Null flags are read straight into the freshly resized null map.
    auto* null_flags = null_map->data();
    for (size_t idx = 0; idx != element_count; ++idx) {
        read_binary(null_flags[idx], buf);
    }

    // Each string is read into a StringRef and then inserted into the column.
    for (size_t remaining = element_count; remaining > 0; --remaining) {
        StringRef str;
        read_string_binary(str, buf);
        nested_column->insert_data(str.data, str.size);
    }
}
void merge(const Self& rhs) {
const auto size = rhs.null_map->size();
null_map->resize(size);
nested_column->reserve(size);
for (size_t i = 0; i < size; i++) {
const auto null_value = rhs.null_map->data()[i];
auto s = rhs.nested_column->get_data_at(i);
null_map->data()[i] = null_value;
nested_column->insert_data(s.data, s.size);
}
}
};
//ShowNull is just used to support array_agg because array_agg needs to display NULL
@ -491,25 +577,21 @@ public:
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena* arena) const override {
auto& data = this->data(place);
auto& rhs_data = this->data(rhs);
const auto& rhs_data = this->data(rhs);
if constexpr (ENABLE_ARENA) {
data.merge(rhs_data, arena);
} else if constexpr (!ShowNull::value) {
} else {
data.merge(rhs_data);
}
}
void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
if constexpr (!ShowNull::value) {
this->data(place).write(buf);
}
this->data(place).write(buf);
}
void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena*) const override {
if constexpr (!ShowNull::value) {
this->data(place).read(buf);
}
this->data(place).read(buf);
}
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {

View File

@ -135,13 +135,43 @@ struct AggregateFunctionMapAggData {
num_rows);
dst_key_column.get_nested_column().insert_range_from(*_key_column, 0, num_rows);
dst.get_values().insert_range_from(*_value_column, 0, num_rows);
if (offsets.size() == 0) {
if (offsets.empty()) {
offsets.push_back(num_rows);
} else {
offsets.push_back(offsets.back() + num_rows);
}
}
// Serializes this map state into `buf`: the entry count (taken from the key
// column), then all keys, then all values. read() expects the same layout.
void write(BufferWritable& buf) const {
    const size_t entry_count = _key_column->size();
    write_binary(entry_count, buf);

    // Keys are accessed through the concrete key column type.
    const auto& keys = assert_cast<KeyColumnType&>(*_key_column);
    for (size_t row = 0; row != entry_count; ++row) {
        write_binary(keys.get_data_at(row), buf);
    }

    // Values follow, one per key, in the same row order.
    for (size_t row = 0; row != entry_count; ++row) {
        write_binary(_value_column->get_data_at(row), buf);
    }
}
// Restores a map state from `buf` using the layout emitted by write():
// an entry count, then that many keys, then that many values.
void read(BufferReadable& buf) {
size_t size = 0;
read_binary(size, buf);
StringRef key;
for (size_t i = 0; i < size; i++) {
read_binary(key, buf);
// Keys already present in _map are not inserted into the key column again.
// NOTE(review): the value loop below still inserts all `size` values, so a
// skipped key would leave _key_column and _value_column with different
// lengths. Keys read here are also not added to _map in this function.
// Presumably read() is only ever called on an empty state — confirm.
if (_map.find(key) != _map.cend()) {
continue;
}
// Re-point `key` at the pointer returned by _arena.insert before the bytes
// are inserted into the key column.
key.data = _arena.insert(key.data, key.size);
assert_cast<KeyColumnType&>(*_key_column).insert_data(key.data, key.size);
}
StringRef val;
// Values are inserted unconditionally, one per serialized entry.
for (size_t i = 0; i < size; i++) {
read_binary(val, buf);
_value_column->insert_data(val.data, val.size);
}
}
private:
using KeyColumnType =
std::conditional_t<std::is_same_v<String, K>, ColumnString, ColumnVectorOrDecimal<K>>;
@ -205,16 +235,13 @@ public:
this->data(place).merge(this->data(rhs));
}
void serialize(ConstAggregateDataPtr /* __restrict place */,
BufferWritable& /* buf */) const override {
LOG(FATAL) << "__builtin_unreachable";
__builtin_unreachable();
void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
this->data(place).write(buf);
}
void deserialize(AggregateDataPtr /* __restrict place */, BufferReadable& /* buf */,
void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena*) const override {
LOG(FATAL) << "__builtin_unreachable";
__builtin_unreachable();
this->data(place).read(buf);
}
void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst,

View File

@ -23,13 +23,17 @@
#include <functional>
#include <memory>
#include <string>
#include <string_view>
#include <unordered_map>
#include <utility>
#include <vector>
#include "agent/be_exec_version_manager.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/common/assert_cast.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_array.h"
#include "vec/data_types/data_type_nullable.h"
namespace doris::vectorized {
using DataTypePtr = std::shared_ptr<const IDataType>;
@ -51,7 +55,7 @@ public:
private:
using AggregateFunctions = std::unordered_map<std::string, Creator>;
constexpr static std::string_view combiner_names[] = {"_foreach"};
AggregateFunctions aggregate_functions;
AggregateFunctions nullable_aggregate_functions;
std::unordered_map<std::string, std::string> function_alias;
@ -71,6 +75,23 @@ public:
}
}
// Returns true when `name` ends with the "_foreach" combiner suffix.
// The comparison is case-sensitive; a name shorter than the suffix is
// never a match, while the bare suffix itself is.
static bool is_foreach(const std::string& name) {
    constexpr std::string_view foreach_suffix = "_foreach";
    const std::string_view candidate {name};
    if (candidate.size() < foreach_suffix.size()) {
        return false;
    }
    // Compare only the trailing suffix-sized window of the name.
    const size_t tail_pos = candidate.size() - foreach_suffix.size();
    return candidate.compare(tail_pos, foreach_suffix.size(), foreach_suffix) == 0;
}
// Reports whether the function wrapped by a 'foreach' combiner should be
// created as nullable.
// The return value of the 'foreach' function is 'null' or 'array<type>', so
// `data_type` is expected to be nullable (DCHECK). The answer is the
// nullability of 'type', i.e. of the array's nested element type.
static bool result_nullable_by_foreach(DataTypePtr& data_type) {
    DCHECK(data_type->is_nullable());
    // Keep the unwrapped type alive in a named variable while it is inspected.
    const DataTypePtr unwrapped_type = remove_nullable(data_type);
    const auto* array_type = assert_cast<const DataTypeArray*>(unwrapped_type.get());
    const auto& element_type = array_type->get_nested_type();
    return element_type->is_nullable();
}
void register_distinct_function_combinator(const Creator& creator, const std::string& prefix,
bool nullable = false) {
auto& functions = nullable ? nullable_aggregate_functions : aggregate_functions;
@ -152,6 +173,9 @@ public:
// Registers `alias` as another name for `name`, and does the same for every
// combiner-suffixed variant: for each suffix in combiner_names,
// alias + suffix is mapped to name + suffix.
void register_alias(const std::string& name, const std::string& alias) {
    function_alias[alias] = name;
    for (const std::string_view combiner : combiner_names) {
        const std::string suffix(combiner);
        const std::string combined_alias = alias + suffix;
        const std::string combined_name = name + suffix;
        function_alias[combined_alias] = combined_name;
    }
}
/// @TEMPORARY: for be_exec_version < AGG_FUNCTION_NEW

View File

@ -194,9 +194,16 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc,
_fn.name.function_name);
}
} else {
_function = AggregateFunctionSimpleFactory::instance().get(
_fn.name.function_name, argument_types, _data_type->is_nullable(),
state->be_exec_version(), state->enable_decima256());
if (AggregateFunctionSimpleFactory::is_foreach(_fn.name.function_name)) {
_function = AggregateFunctionSimpleFactory::instance().get(
_fn.name.function_name, argument_types,
AggregateFunctionSimpleFactory::result_nullable_by_foreach(_data_type),
state->be_exec_version(), state->enable_decima256());
} else {
_function = AggregateFunctionSimpleFactory::instance().get(
_fn.name.function_name, argument_types, _data_type->is_nullable(),
state->be_exec_version(), state->enable_decima256());
}
}
if (_function == nullptr) {
return Status::InternalError("Agg Function {} is not implemented", _fn.signature);

View File

@ -0,0 +1,84 @@
---
{
"title": "FOREACH",
"language": "en"
}
---
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
## FOREACH
<version since="2.1.0">
</version>
### description
#### Syntax
`AGGREGATE_FUNCTION_FOREACH(arg...)`
Converts an aggregate function for tables into an aggregate function for arrays that aggregates the corresponding array items and returns an array of results. For example, `sum_foreach` for the arrays [1, 2], [3, 4, 5] and [6, 7] returns the result [10, 13, 5] after adding together the corresponding array items.
### example
```
mysql [test]>select a , s from db;
+-----------+---------------+
| a | s |
+-----------+---------------+
| [1, 2, 3] | ["ab", "123"] |
| [20] | ["cd"] |
| [100] | ["efg"] |
| NULL | NULL |
| [null, 2] | [null, "c"] |
+-----------+---------------+
mysql [test]>select sum_foreach(a) from db;
+----------------+
| sum_foreach(a) |
+----------------+
| [121, 4, 3] |
+----------------+
mysql [test]>select count_foreach(s) from db;
+------------------+
| count_foreach(s) |
+------------------+
| [3, 2] |
+------------------+
mysql [test]>select array_agg_foreach(a) from db;
+-----------------------------------+
| array_agg_foreach(a) |
+-----------------------------------+
| [[1, 20, 100, null], [2, 2], [3]] |
+-----------------------------------+
mysql [test]>select map_agg_foreach(a,a) from db;
+---------------------------------------+
| map_agg_foreach(a, a) |
+---------------------------------------+
| [{1:1, 20:20, 100:100}, {2:2}, {3:3}] |
+---------------------------------------+
```
### keywords
FOREACH

View File

@ -0,0 +1,82 @@
---
{
"title": "FOREACH",
"language": "zh-CN"
}
---
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
## FOREACH
<version since="2.1.0">
</version>
### description
#### Syntax
`AGGREGATE_FUNCTION_FOREACH(arg...)`
将作用于表的聚合函数转换为作用于数组的聚合函数：对各数组中相同位置的元素分别进行聚合，并返回由各位置聚合结果组成的数组。例如，`sum_foreach` 对数组 [1, 2]、[3, 4, 5] 和 [6, 7] 按位置求和后，返回结果 [10, 13, 5]。
### example
```
mysql [test]>select a , s from db;
+-----------+---------------+
| a | s |
+-----------+---------------+
| [1, 2, 3] | ["ab", "123"] |
| [20] | ["cd"] |
| [100] | ["efg"] |
| NULL | NULL |
| [null, 2] | [null, "c"] |
+-----------+---------------+
mysql [test]>select sum_foreach(a) from db;
+----------------+
| sum_foreach(a) |
+----------------+
| [121, 4, 3] |
+----------------+
mysql [test]>select count_foreach(s) from db;
+------------------+
| count_foreach(s) |
+------------------+
| [3, 2] |
+------------------+
mysql [test]>select array_agg_foreach(a) from db;
+-----------------------------------+
| array_agg_foreach(a) |
+-----------------------------------+
| [[1, 20, 100, null], [2, 2], [3]] |
+-----------------------------------+
mysql [test]>select map_agg_foreach(a,a) from db;
+---------------------------------------+
| map_agg_foreach(a, a) |
+---------------------------------------+
| [{1:1, 20:20, 100:100}, {2:2}, {3:3}] |
+---------------------------------------+
```
### keywords
FOREACH

View File

@ -62,7 +62,7 @@ public class ForEachCombinator extends AggregateFunction
@Override
public List<FunctionSignature> getSignatures() {
return nested.getSignatures().stream().map(sig -> {
return sig.withReturnType(ArrayType.of(sig.returnType)).withArgumentTypes(false,
return sig.withReturnType(ArrayType.of(sig.returnType)).withArgumentTypes(sig.hasVarArgs,
sig.argumentsTypes.stream().map(arg -> {
return ArrayType.of(arg);
}).collect(ImmutableList.toImmutableList()));

View File

@ -12,7 +12,10 @@
["{"20":1,"100":1,"1":1}", "{"2":2}", "{"3":1}"] ["{"20":1,"100":1,"1":1}", "{"2":2}", "{"3":1}"] [[100, 20, 1], [2], [3]] [[100, 20, 1], [2], [3]]
-- !sql --
[3, 2, 1] ["[{"cbe":{"100":1,"1":1,"20":1},"notnull":3,"null":1,"all":4}]", "[{"cbe":{"2":2},"notnull":2,"null":0,"all":2}]", "[{"cbe":{"3":1},"notnull":1,"null":0,"all":1}]"]
[3, 2, 1] ["[{"cbe":{"100":1,"1":1,"20":1},"notnull":3,"null":1,"all":4}]", "[{"cbe":{"2":2},"notnull":2,"null":0,"all":2}]", "[{"cbe":{"3":1},"notnull":1,"null":0,"all":1}]"] [3, 1, 1]
-- !sql --
["{"num_buckets":3,"buckets":[{"lower":"1","upper":"1","ndv":1,"count":1,"pre_sum":0},{"lower":"20","upper":"20","ndv":1,"count":1,"pre_sum":1},{"lower":"100","upper":"100","ndv":1,"count":1,"pre_sum":2}]}", "{"num_buckets":1,"buckets":[{"lower":"2","upper":"2","ndv":1,"count":2,"pre_sum":0}]}", "{"num_buckets":1,"buckets":[{"lower":"3","upper":"3","ndv":1,"count":1,"pre_sum":0}]}"]
-- !sql --
[100, 2, 3]
@ -26,3 +29,21 @@
-- !sql --
[0, 2, 3] [117, 2, 3] [113, 0, 3]
-- !sql --
["ab,cd,efg", "123,c", "114514"] ["ababcdabefg", "123123c", "114514"]
-- !sql --
[[1], [1], [1]] [[1, 1], [1, 1], [1, 1]] [[1, 1, 1], [1, 1, 1], [1, 1, 1]] [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]
-- !sql --
["ab", "123", "114514"] [1, 2, 3]
-- !sql --
[[100, 20, 1], [2], [3]] [["efg", "cd", "ab"], ["c", "123"], ["114514"]] [[1], [2], [3]]
-- !sql --
[[1, 20, 100], [2, 2], [3]] [["ab", "cd", "efg"], ["123", "c"], ["114514"]] [[1], [2, 2], [3]]
-- !sql --
[{1:1, 20:20, 100:100}, {2:2}, {3:3}] [{1:"ab", 20:"cd", 100:"efg"}, {2:"123"}, {3:"114514"}] [{"ab":"ab", "cd":"cd", "efg":"efg"}, {"123":"123", "c":"c"}, {"114514":"114514"}]

View File

@ -20,10 +20,17 @@ suite("test_agg_foreach") {
// now support min min_by max max_by avg avg_weighted sum stddev stddev_samp_foreach variance var_samp
// covar covar_samp corr
// topn topn_array topn_weighted
// count count_by_enum
// count count_by_enum approx_count_distinct
// PERCENTILE PERCENTILE_ARRAY PERCENTILE_APPROX
// histogram
// GROUP_BIT_AND GROUP_BIT_OR GROUP_BIT_XOR
// GROUP_BIT_AND GROUP_BIT_OR GROUP_BIT_XOR
// any_value
// array_agg map_agg
// collect_set collect_list
// retention
// not support
// GROUP_BITMAP_XOR BITMAP_UNION HLL_UNION_AGG GROUPING GROUPING_ID BITMAP_AGG SEQUENCE-MATCH SEQUENCE-COUNT
sql """ set enable_nereids_planner=true;"""
sql """ set enable_fallback_to_original_planner=false;"""
@ -48,7 +55,7 @@ suite("test_agg_foreach") {
"""
sql """
insert into foreach_table values
(1,[1,2,3],[[1],[1,2,3],[2]],["ab","123"]),
(1,[1,2,3],[[1],[1,2,3],[2]],["ab","123","114514"]),
(2,[20],[[2]],["cd"]),
(3,[100],[[1]],["efg"]) ,
(4,null,[null],null),
@ -73,7 +80,11 @@ suite("test_agg_foreach") {
qt_sql """
select count_foreach(a) , count_by_enum_foreach(a) from foreach_table;
select count_foreach(a) , count_by_enum_foreach(a) , approx_count_distinct_foreach(a) from foreach_table;
"""
qt_sql """
select histogram_foreach(a) from foreach_table;
"""
qt_sql """
@ -92,4 +103,28 @@ suite("test_agg_foreach") {
qt_sql """
select GROUP_BIT_AND_foreach(a), GROUP_BIT_OR_foreach(a), GROUP_BIT_XOR_foreach(a) from foreach_table;
"""
qt_sql """
select GROUP_CONCAT_foreach(s), GROUP_CONCAT_foreach(s,s) from foreach_table;
"""
qt_sql """
select retention_foreach(a), retention_foreach(a,a ),retention_foreach(a,a,a) , retention_foreach(a,a,a ,a) from foreach_table;
"""
qt_sql """
select any_value_foreach(s), any_value_foreach(a) from foreach_table;
"""
qt_sql """
select collect_set_foreach(a), collect_set_foreach(s) , collect_set_foreach(a,a) from foreach_table;
"""
qt_sql """
select collect_list_foreach(a), collect_list_foreach(s) , collect_list_foreach(a,a) from foreach_table;
"""
qt_sql """
select map_agg_foreach(a,a), map_agg_foreach(a,s) , map_agg_foreach(s,s) from foreach_table;
"""
}