[feature-wip](array-type) Add array aggregation functions (#10108)
This commit is contained in:
228
be/test/vec/function/function_array_aggregation_test.cpp
Normal file
228
be/test/vec/function/function_array_aggregation_test.cpp
Normal file
@ -0,0 +1,228 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#include "function_test_util.h"
|
||||
#include "vec/core/field.h"
|
||||
#include "vec/data_types/data_type_number.h"
|
||||
|
||||
namespace doris {
|
||||
namespace vectorized {
|
||||
|
||||
template <typename T>
|
||||
struct AnyValue {
|
||||
T value {};
|
||||
bool is_null = false;
|
||||
|
||||
AnyValue(T v) : value(std::move(v)) {}
|
||||
|
||||
AnyValue(std::nullptr_t) : is_null(true) {}
|
||||
};
|
||||
|
||||
using IntDataSet = std::vector<std::pair<std::vector<AnyValue<int>>, AnyValue<int>>>;
|
||||
|
||||
template <typename T, typename ReturnType = T>
|
||||
void check_function(const std::string& func_name, const IntDataSet data_set,
|
||||
bool nullable = false) {
|
||||
InputTypeSet input_types;
|
||||
if (!nullable) {
|
||||
input_types = {TypeIndex::Array, ut_type::get_type_index<T>()};
|
||||
} else {
|
||||
input_types = {TypeIndex::Array, TypeIndex::Nullable, ut_type::get_type_index<T>()};
|
||||
}
|
||||
DataSet converted_data_set;
|
||||
for (const auto& row : data_set) {
|
||||
Array array;
|
||||
for (auto any_value : row.first) {
|
||||
if (any_value.is_null) {
|
||||
array.push_back(Field());
|
||||
} else {
|
||||
array.push_back(ut_type::convert_to<T>(any_value.value));
|
||||
}
|
||||
}
|
||||
if (!row.second.is_null) {
|
||||
converted_data_set.emplace_back(std::make_pair<CellSet, Expect>(
|
||||
{array}, ut_type::convert_to<ReturnType>(row.second.value)));
|
||||
} else {
|
||||
converted_data_set.emplace_back(std::make_pair<CellSet, Expect>({array}, Null()));
|
||||
}
|
||||
}
|
||||
check_function<ReturnType, true>(func_name, input_types, converted_data_set);
|
||||
}
|
||||
|
||||
TEST(VFunctionArrayAggregationTest, TestArrayMin) {
|
||||
const std::string func_name = "array_min";
|
||||
IntDataSet data_set = {
|
||||
{{}, nullptr},
|
||||
{{1, 2, 3}, 1},
|
||||
};
|
||||
check_function<DataTypeInt8>(func_name, data_set);
|
||||
check_function<DataTypeInt16>(func_name, data_set);
|
||||
check_function<DataTypeInt32>(func_name, data_set);
|
||||
check_function<DataTypeInt64>(func_name, data_set);
|
||||
check_function<DataTypeInt128>(func_name, data_set);
|
||||
check_function<DataTypeFloat32>(func_name, data_set);
|
||||
check_function<DataTypeFloat64>(func_name, data_set);
|
||||
}
|
||||
|
||||
TEST(VFunctionArrayAggregationTest, TestArrayMinNullable) {
|
||||
const std::string func_name = "array_min";
|
||||
IntDataSet data_set = {
|
||||
{{}, nullptr},
|
||||
{{nullptr}, nullptr},
|
||||
{{1, nullptr, 3}, 1},
|
||||
};
|
||||
check_function<DataTypeInt8>(func_name, data_set, true);
|
||||
check_function<DataTypeInt16>(func_name, data_set, true);
|
||||
check_function<DataTypeInt32>(func_name, data_set, true);
|
||||
check_function<DataTypeInt64>(func_name, data_set, true);
|
||||
check_function<DataTypeInt128>(func_name, data_set, true);
|
||||
check_function<DataTypeFloat32>(func_name, data_set, true);
|
||||
check_function<DataTypeFloat64>(func_name, data_set, true);
|
||||
}
|
||||
|
||||
TEST(VFunctionArrayAggregationTest, TestArrayMax) {
|
||||
const std::string func_name = "array_max";
|
||||
IntDataSet data_set = {
|
||||
{{}, nullptr},
|
||||
{{1, 2, 3}, 3},
|
||||
};
|
||||
check_function<DataTypeInt8>(func_name, data_set);
|
||||
check_function<DataTypeInt16>(func_name, data_set);
|
||||
check_function<DataTypeInt32>(func_name, data_set);
|
||||
check_function<DataTypeInt64>(func_name, data_set);
|
||||
check_function<DataTypeInt128>(func_name, data_set);
|
||||
check_function<DataTypeFloat32>(func_name, data_set);
|
||||
check_function<DataTypeFloat64>(func_name, data_set);
|
||||
}
|
||||
|
||||
TEST(VFunctionArrayAggregationTest, TestArrayMaxNullable) {
|
||||
const std::string func_name = "array_max";
|
||||
IntDataSet data_set = {
|
||||
{{}, nullptr},
|
||||
{{nullptr}, nullptr},
|
||||
{{1, nullptr, 3}, 3},
|
||||
};
|
||||
check_function<DataTypeInt8>(func_name, data_set, true);
|
||||
check_function<DataTypeInt16>(func_name, data_set, true);
|
||||
check_function<DataTypeInt32>(func_name, data_set, true);
|
||||
check_function<DataTypeInt64>(func_name, data_set, true);
|
||||
check_function<DataTypeInt128>(func_name, data_set, true);
|
||||
check_function<DataTypeFloat32>(func_name, data_set, true);
|
||||
check_function<DataTypeFloat64>(func_name, data_set, true);
|
||||
}
|
||||
|
||||
TEST(VFunctionArrayAggregationTest, TestArraySum) {
|
||||
const std::string func_name = "array_sum";
|
||||
IntDataSet data_set = {
|
||||
{{}, nullptr},
|
||||
{{1, 2, 3}, 6},
|
||||
};
|
||||
check_function<DataTypeInt8, DataTypeInt64>(func_name, data_set);
|
||||
check_function<DataTypeInt16, DataTypeInt64>(func_name, data_set);
|
||||
check_function<DataTypeInt32, DataTypeInt64>(func_name, data_set);
|
||||
check_function<DataTypeInt64, DataTypeInt64>(func_name, data_set);
|
||||
check_function<DataTypeInt128, DataTypeInt128>(func_name, data_set);
|
||||
check_function<DataTypeFloat32, DataTypeFloat64>(func_name, data_set);
|
||||
check_function<DataTypeFloat64, DataTypeFloat64>(func_name, data_set);
|
||||
}
|
||||
|
||||
TEST(VFunctionArrayAggregationTest, TestArraySumNullable) {
|
||||
const std::string func_name = "array_sum";
|
||||
IntDataSet data_set = {
|
||||
{{}, nullptr},
|
||||
{{nullptr}, nullptr},
|
||||
{{1, nullptr, 3}, 4},
|
||||
};
|
||||
check_function<DataTypeInt8, DataTypeInt64>(func_name, data_set, true);
|
||||
check_function<DataTypeInt16, DataTypeInt64>(func_name, data_set, true);
|
||||
check_function<DataTypeInt32, DataTypeInt64>(func_name, data_set, true);
|
||||
check_function<DataTypeInt64, DataTypeInt64>(func_name, data_set, true);
|
||||
check_function<DataTypeInt128, DataTypeInt128>(func_name, data_set, true);
|
||||
check_function<DataTypeFloat32, DataTypeFloat64>(func_name, data_set, true);
|
||||
check_function<DataTypeFloat64, DataTypeFloat64>(func_name, data_set, true);
|
||||
}
|
||||
|
||||
TEST(VFunctionArrayAggregationTest, TestArrayAverage) {
|
||||
const std::string func_name = "array_avg";
|
||||
IntDataSet data_set = {
|
||||
{{}, nullptr},
|
||||
{{1, 2, 3}, 2},
|
||||
};
|
||||
check_function<DataTypeInt8, DataTypeFloat64>(func_name, data_set);
|
||||
check_function<DataTypeInt16, DataTypeFloat64>(func_name, data_set);
|
||||
check_function<DataTypeInt32, DataTypeFloat64>(func_name, data_set);
|
||||
check_function<DataTypeInt64, DataTypeFloat64>(func_name, data_set);
|
||||
check_function<DataTypeInt128, DataTypeFloat64>(func_name, data_set);
|
||||
check_function<DataTypeFloat32, DataTypeFloat64>(func_name, data_set);
|
||||
check_function<DataTypeFloat64, DataTypeFloat64>(func_name, data_set);
|
||||
}
|
||||
|
||||
TEST(VFunctionArrayAggregationTest, TestArrayAverageNullable) {
|
||||
const std::string func_name = "array_avg";
|
||||
IntDataSet data_set = {
|
||||
{{}, nullptr},
|
||||
{{nullptr}, nullptr},
|
||||
{{1, nullptr, 3}, 2},
|
||||
};
|
||||
check_function<DataTypeInt8, DataTypeFloat64>(func_name, data_set, true);
|
||||
check_function<DataTypeInt16, DataTypeFloat64>(func_name, data_set, true);
|
||||
check_function<DataTypeInt32, DataTypeFloat64>(func_name, data_set, true);
|
||||
check_function<DataTypeInt64, DataTypeFloat64>(func_name, data_set, true);
|
||||
check_function<DataTypeInt128, DataTypeFloat64>(func_name, data_set, true);
|
||||
check_function<DataTypeFloat32, DataTypeFloat64>(func_name, data_set, true);
|
||||
check_function<DataTypeFloat64, DataTypeFloat64>(func_name, data_set, true);
|
||||
}
|
||||
|
||||
TEST(VFunctionArrayAggregationTest, TestArrayProduct) {
|
||||
const std::string func_name = "array_product";
|
||||
IntDataSet data_set = {
|
||||
{{}, nullptr},
|
||||
{{1, 2, 3}, 6},
|
||||
};
|
||||
check_function<DataTypeInt8, DataTypeFloat64>(func_name, data_set);
|
||||
check_function<DataTypeInt16, DataTypeFloat64>(func_name, data_set);
|
||||
check_function<DataTypeInt32, DataTypeFloat64>(func_name, data_set);
|
||||
check_function<DataTypeInt64, DataTypeFloat64>(func_name, data_set);
|
||||
check_function<DataTypeInt128, DataTypeFloat64>(func_name, data_set);
|
||||
check_function<DataTypeFloat32, DataTypeFloat64>(func_name, data_set);
|
||||
check_function<DataTypeFloat64, DataTypeFloat64>(func_name, data_set);
|
||||
}
|
||||
|
||||
TEST(VFunctionArrayAggregationTest, TestArrayProductNullable) {
|
||||
const std::string func_name = "array_product";
|
||||
IntDataSet data_set = {
|
||||
{{}, nullptr},
|
||||
{{nullptr}, nullptr},
|
||||
{{1, nullptr, 3}, 3},
|
||||
};
|
||||
check_function<DataTypeInt8, DataTypeFloat64>(func_name, data_set, true);
|
||||
check_function<DataTypeInt16, DataTypeFloat64>(func_name, data_set, true);
|
||||
check_function<DataTypeInt32, DataTypeFloat64>(func_name, data_set, true);
|
||||
check_function<DataTypeInt64, DataTypeFloat64>(func_name, data_set, true);
|
||||
check_function<DataTypeInt128, DataTypeFloat64>(func_name, data_set, true);
|
||||
check_function<DataTypeFloat32, DataTypeFloat64>(func_name, data_set, true);
|
||||
check_function<DataTypeFloat64, DataTypeFloat64>(func_name, data_set, true);
|
||||
}
|
||||
|
||||
} // namespace vectorized
|
||||
} // namespace doris
|
||||
Reference in New Issue
Block a user