[Enhancement](function) refactor and optimize some function registration (#16955)

refactor and optimize some function registration
This commit is contained in:
Pxl
2023-02-24 10:05:11 +08:00
committed by GitHub
parent 37b9b038c4
commit c4edea5936
18 changed files with 370 additions and 343 deletions

View File

@ -41,25 +41,8 @@ AggregateFunctionPtr create_aggregate_function_avg(const std::string& name,
const bool result_is_nullable) {
assert_unary(name, argument_types);
AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
if (data_type->is_nullable()) {
auto no_null_argument_types = remove_nullable(argument_types);
if (is_decimal(no_null_argument_types[0])) {
res.reset(create_with_decimal_type_null<AggregateFuncAvg>(
no_null_argument_types, *no_null_argument_types[0], no_null_argument_types));
} else {
res.reset(create_with_numeric_type_null<AggregateFuncAvg>(no_null_argument_types,
no_null_argument_types));
}
} else {
if (is_decimal(data_type)) {
res.reset(create_with_decimal_type<AggregateFuncAvg>(*data_type, *data_type,
argument_types));
} else {
res.reset(create_with_numeric_type<AggregateFuncAvg>(*data_type, argument_types));
}
}
AggregateFunctionPtr res(
creator_with_type::create<AggregateFuncAvg>(result_is_nullable, argument_types));
if (!res) {
LOG(WARNING) << fmt::format("Illegal type {} of argument for aggregate function {}",

View File

@ -95,12 +95,7 @@ public:
/// ctor for native types
AggregateFunctionAvg(const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>(argument_types_),
scale(0) {}
/// ctor for Decimals
AggregateFunctionAvg(const IDataType& data_type, const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>(argument_types_),
scale(get_decimal_scale(data_type)) {}
scale(get_decimal_scale(*argument_types_[0])) {}
String get_name() const override { return "avg"; }

View File

@ -21,6 +21,7 @@
#include "vec/aggregate_functions/aggregate_function_bit.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/helpers.h"
namespace doris::vectorized {
@ -34,26 +35,10 @@ AggregateFunctionPtr createAggregateFunctionBitwise(const std::string& name,
" is illegal, because it cannot be used in bitwise operations");
}
auto type = argument_types[0].get();
if (type->is_nullable()) {
type = assert_cast<const DataTypeNullable*>(type)->get_nested_type().get();
}
WhichDataType which(*type);
if (which.is_int8()) {
return AggregateFunctionPtr(new AggregateFunctionBitwise<Int8, Data<Int8>>(argument_types));
} else if (which.is_int16()) {
return AggregateFunctionPtr(
new AggregateFunctionBitwise<Int16, Data<Int16>>(argument_types));
} else if (which.is_int32()) {
return AggregateFunctionPtr(
new AggregateFunctionBitwise<Int32, Data<Int32>>(argument_types));
} else if (which.is_int64()) {
return AggregateFunctionPtr(
new AggregateFunctionBitwise<Int64, Data<Int64>>(argument_types));
} else if (which.is_int128()) {
return AggregateFunctionPtr(
new AggregateFunctionBitwise<Int128, Data<Int128>>(argument_types));
AggregateFunctionPtr res(creator_with_integer_type::create<AggregateFunctionBitwise, Data>(
result_is_nullable, argument_types));
if (res) {
return res;
}
LOG(WARNING) << fmt::format("Illegal type " + argument_types[0]->get_name() +
@ -68,6 +53,15 @@ void register_aggregate_function_bit(AggregateFunctionSimpleFactory& factory) {
createAggregateFunctionBitwise<AggregateFunctionGroupBitAndData>);
factory.register_function("group_bit_xor",
createAggregateFunctionBitwise<AggregateFunctionGroupBitXorData>);
factory.register_function(
"group_bit_or", createAggregateFunctionBitwise<AggregateFunctionGroupBitOrData>, true);
factory.register_function("group_bit_and",
createAggregateFunctionBitwise<AggregateFunctionGroupBitAndData>,
true);
factory.register_function("group_bit_xor",
createAggregateFunctionBitwise<AggregateFunctionGroupBitXorData>,
true);
}
} // namespace doris::vectorized

View File

@ -49,30 +49,30 @@ public:
return nullptr;
}
AggregateFunctionPtr res;
if (arguments.size() == 1) {
res.reset(create_with_numeric_type<AggregateFunctionDistinct,
AggregateFunctionDistinctSingleNumericData>(
*arguments[0], nested_function, arguments));
AggregateFunctionPtr res(
creator_with_numeric_type::create<AggregateFunctionDistinct,
AggregateFunctionDistinctSingleNumericData>(
result_is_nullable, arguments, nested_function));
if (res) {
return res;
}
if (arguments[0]->is_value_unambiguously_represented_in_contiguous_memory_region()) {
return std::make_shared<AggregateFunctionDistinct<
AggregateFunctionDistinctSingleGenericData<true>>>(nested_function,
arguments);
res.reset(creator_without_type::create<AggregateFunctionDistinct<
AggregateFunctionDistinctSingleGenericData<true>>>(
result_is_nullable, arguments, nested_function));
} else {
return std::make_shared<AggregateFunctionDistinct<
AggregateFunctionDistinctSingleGenericData<false>>>(nested_function,
arguments);
res.reset(creator_without_type::create<AggregateFunctionDistinct<
AggregateFunctionDistinctSingleGenericData<false>>>(
result_is_nullable, arguments, nested_function));
}
return res;
}
return std::make_shared<
AggregateFunctionDistinct<AggregateFunctionDistinctMultipleGenericData>>(
nested_function, arguments);
return AggregateFunctionPtr(
creator_without_type::create<
AggregateFunctionDistinct<AggregateFunctionDistinctMultipleGenericData>>(
result_is_nullable, arguments, nested_function));
}
};
@ -93,5 +93,6 @@ void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFact
result_is_nullable);
};
factory.register_distinct_function_combinator(creator, DISTINCT_FUNCTION_PREFIX);
factory.register_distinct_function_combinator(creator, DISTINCT_FUNCTION_PREFIX, true);
}
} // namespace doris::vectorized

View File

@ -26,50 +26,45 @@
namespace doris::vectorized {
/// min, max, any
template <template <typename, bool> class AggregateFunctionTemplate, template <typename> class Data>
template <template <typename> class AggregateFunctionTemplate, template <typename> class Data>
static IAggregateFunction* create_aggregate_function_single_value(const String& name,
const DataTypes& argument_types) {
const DataTypes& argument_types,
const bool result_is_nullable) {
assert_unary(name, argument_types);
const DataTypePtr& argument_type = argument_types[0];
IAggregateFunction* res(creator_with_numeric_type::create<AggregateFunctionTemplate, Data,
SingleValueDataFixed>(
result_is_nullable, argument_types));
if (res) {
return res;
}
res = creator_with_decimal_type::create<AggregateFunctionTemplate, Data,
SingleValueDataDecimal>(result_is_nullable,
argument_types);
if (res) {
return res;
}
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);
WhichDataType which(argument_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<TYPE>>, false>( \
argument_type);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::String) {
return new AggregateFunctionTemplate<Data<SingleValueDataString>, false>(argument_type);
return creator_without_type::create<AggregateFunctionTemplate<Data<SingleValueDataString>>>(
result_is_nullable, argument_types);
}
if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::Date) {
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Int64>>, false>(
argument_type);
return creator_without_type::create<
AggregateFunctionTemplate<Data<SingleValueDataFixed<Int64>>>>(result_is_nullable,
argument_types);
}
if (which.idx == TypeIndex::DateV2) {
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<UInt32>>, false>(
argument_type);
return creator_without_type::create<
AggregateFunctionTemplate<Data<SingleValueDataFixed<UInt32>>>>(result_is_nullable,
argument_types);
}
if (which.idx == TypeIndex::DateTimeV2) {
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<UInt64>>, false>(
argument_type);
}
if (which.idx == TypeIndex::Decimal32) {
return new AggregateFunctionTemplate<Data<SingleValueDataDecimal<Decimal32>>, false>(
argument_type);
}
if (which.idx == TypeIndex::Decimal64) {
return new AggregateFunctionTemplate<Data<SingleValueDataDecimal<Decimal64>>, false>(
argument_type);
}
if (which.idx == TypeIndex::Decimal128) {
return new AggregateFunctionTemplate<Data<SingleValueDataDecimal<Decimal128>>, false>(
argument_type);
}
if (which.idx == TypeIndex::Decimal128I) {
return new AggregateFunctionTemplate<Data<SingleValueDataDecimal<Decimal128I>>, false>(
argument_type);
return creator_without_type::create<
AggregateFunctionTemplate<Data<SingleValueDataFixed<UInt64>>>>(result_is_nullable,
argument_types);
}
return nullptr;
}
@ -79,7 +74,8 @@ AggregateFunctionPtr create_aggregate_function_max(const std::string& name,
const bool result_is_nullable) {
return AggregateFunctionPtr(
create_aggregate_function_single_value<AggregateFunctionsSingleValue,
AggregateFunctionMaxData>(name, argument_types));
AggregateFunctionMaxData>(name, argument_types,
result_is_nullable));
}
AggregateFunctionPtr create_aggregate_function_min(const std::string& name,
@ -87,7 +83,8 @@ AggregateFunctionPtr create_aggregate_function_min(const std::string& name,
const bool result_is_nullable) {
return AggregateFunctionPtr(
create_aggregate_function_single_value<AggregateFunctionsSingleValue,
AggregateFunctionMinData>(name, argument_types));
AggregateFunctionMinData>(name, argument_types,
result_is_nullable));
}
AggregateFunctionPtr create_aggregate_function_any(const std::string& name,
@ -95,7 +92,8 @@ AggregateFunctionPtr create_aggregate_function_any(const std::string& name,
const bool result_is_nullable) {
return AggregateFunctionPtr(
create_aggregate_function_single_value<AggregateFunctionsSingleValue,
AggregateFunctionAnyData>(name, argument_types));
AggregateFunctionAnyData>(name, argument_types,
result_is_nullable));
}
void register_aggregate_function_minmax(AggregateFunctionSimpleFactory& factory) {
@ -103,6 +101,10 @@ void register_aggregate_function_minmax(AggregateFunctionSimpleFactory& factory)
factory.register_function("min", create_aggregate_function_min);
factory.register_function("any", create_aggregate_function_any);
factory.register_function("max", create_aggregate_function_max, true);
factory.register_function("min", create_aggregate_function_min, true);
factory.register_function("any", create_aggregate_function_any, true);
factory.register_alias("any", "any_value");
}

View File

@ -485,19 +485,16 @@ struct AggregateFunctionAnyData : Data {
static const char* name() { return "any"; }
};
template <typename Data, bool AllocatesMemoryInArena>
template <typename Data>
class AggregateFunctionsSingleValue final
: public IAggregateFunctionDataHelper<
Data, AggregateFunctionsSingleValue<Data, AllocatesMemoryInArena>> {
: public IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data>> {
private:
DataTypePtr& type;
using Base = IAggregateFunctionDataHelper<
Data, AggregateFunctionsSingleValue<Data, AllocatesMemoryInArena>>;
using Base = IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data>>;
public:
AggregateFunctionsSingleValue(const DataTypePtr& type_)
: IAggregateFunctionDataHelper<
Data, AggregateFunctionsSingleValue<Data, AllocatesMemoryInArena>>({type_}),
AggregateFunctionsSingleValue(const DataTypes& arguments)
: IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data>>(arguments),
type(this->argument_types[0]) {
if (StringRef(Data::name()) == StringRef("min") ||
StringRef(Data::name()) == StringRef("max")) {
@ -535,8 +532,6 @@ public:
this->data(place).read(buf);
}
bool allocates_memory_in_arena() const override { return AllocatesMemoryInArena; }
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
this->data(place).insert_result_into(to);
}

View File

@ -310,10 +310,7 @@ public:
if (has_null) {
for (size_t i = 0; i < batch_size; ++i) {
if (!column->is_null_at(i)) {
this->set_flag(place);
this->add(place, columns, i, arena);
}
this->add(place, columns, i, arena);
}
} else {
this->set_flag(place);
@ -329,10 +326,7 @@ public:
if (has_null) {
for (size_t i = batch_begin; i <= batch_end; ++i) {
if (!column->is_null_at(i)) {
this->set_flag(place);
this->add(place, columns, i, arena);
}
this->add(place, columns, i, arena);
}
} else {
this->set_flag(place);
@ -587,14 +581,6 @@ public:
}
}
void add_not_nullable(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num, Arena* arena) const {
const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
this->set_flag(place);
const IColumn* nested_column = &column->get_nested_column();
this->nested_function->add(this->nested_place(place), &nested_column, row_num, arena);
}
void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
const IColumn** columns, Arena* arena, bool agg_many) const override {
const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
@ -617,10 +603,7 @@ public:
if (has_null) {
for (size_t i = 0; i < batch_size; ++i) {
if (!column->is_null_at(i)) {
this->set_flag(place);
this->add(place, columns, i, arena);
}
this->add(place, columns, i, arena);
}
} else {
this->set_flag(place);
@ -636,10 +619,7 @@ public:
if (has_null) {
for (size_t i = batch_begin; i <= batch_end; ++i) {
if (!column->is_null_at(i)) {
this->set_flag(place);
this->add(place, columns, i, arena);
}
this->add(place, columns, i, arena);
}
} else {
this->set_flag(place);
@ -650,4 +630,70 @@ public:
}
}
};
/// Nullable adapter for a nested aggregate function that takes more than one argument.
///
/// For every row, each argument flagged as nullable is unwrapped to its nested column before
/// the row is forwarded to the nested function. A row in which any nullable argument is NULL
/// is skipped entirely (no flag is set and the nested function is not called).
///
/// The constructor logs at FATAL severity when exactly one argument is supplied or when the
/// argument count exceeds MAX_ARGS.
template <typename NestFuction, bool result_is_nullable>
class AggregateFunctionNullVariadicInline final
        : public AggregateFunctionNullBaseInline<
                  NestFuction, result_is_nullable,
                  AggregateFunctionNullVariadicInline<NestFuction, result_is_nullable>> {
public:
    /// @param nested_function_ the wrapped aggregate function, forwarded to the base class.
    /// @param arguments        the argument types; their nullability is recorded per index.
    AggregateFunctionNullVariadicInline(IAggregateFunction* nested_function_,
                                        const DataTypes& arguments)
            : AggregateFunctionNullBaseInline<
                      NestFuction, result_is_nullable,
                      AggregateFunctionNullVariadicInline<NestFuction, result_is_nullable>>(
                      nested_function_, arguments),
              number_of_arguments(arguments.size()) {
        if (number_of_arguments == 1) {
            LOG(FATAL)
                    << "Logical error: single argument is passed to AggregateFunctionNullVariadic";
        }

        if (number_of_arguments > MAX_ARGS) {
            LOG(FATAL) << fmt::format(
                    "Maximum number of arguments for aggregate function with Nullable types is {}",
                    size_t(MAX_ARGS));
        }

        for (size_t i = 0; i < number_of_arguments; ++i) {
            is_nullable[i] = arguments[i]->is_nullable();
        }
    }

    /// Adds row `row_num` to the aggregation state at `place`, unless any nullable argument
    /// is NULL in that row.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
             Arena* arena) const override {
        /// This container stores the columns we really pass to the nested function.
        /// Fixed-capacity storage replaces a runtime-sized C array, which is a compiler
        /// extension rather than standard C++. Only the first number_of_arguments slots are
        /// written and read; the constructor rejects counts above MAX_ARGS.
        std::array<const IColumn*, MAX_ARGS> nested_columns;

        for (size_t i = 0; i < number_of_arguments; ++i) {
            if (is_nullable[i]) {
                const ColumnNullable& nullable_col =
                        assert_cast<const ColumnNullable&>(*columns[i]);
                if (nullable_col.is_null_at(row_num)) {
                    /// If at least one column has a null value in the current row,
                    /// we don't process this row.
                    return;
                }
                nested_columns[i] = &nullable_col.get_nested_column();
            } else {
                nested_columns[i] = columns[i];
            }
        }

        this->set_flag(place);
        this->nested_function->add(this->nested_place(place), nested_columns.data(), row_num,
                                   arena);
    }

    /// Delegates to the nested function.
    bool allocates_memory_in_arena() const override {
        return this->nested_function->allocates_memory_in_arena();
    }

private:
    // The array length is fixed in the implementation of some aggregate functions.
    // Therefore we choose 256 as the appropriate maximum length limit.
    static constexpr size_t MAX_ARGS = 256;
    size_t number_of_arguments = 0;
    std::array<char, MAX_ARGS>
            is_nullable; /// Plain array is better than std::vector due to one indirection less.
};
} // namespace doris::vectorized

View File

@ -36,18 +36,22 @@ AggregateFunctionPtr create_aggregate_function_orthogonal(const std::string& nam
} else if (argument_types.size() == 1) {
return std::make_shared<AggFunctionOrthBitmapFunc<Impl<StringRef>>>(argument_types);
} else {
const IDataType& argument_type = *argument_types[1];
AggregateFunctionPtr res(create_with_numeric_type<AggFunctionOrthBitmapFunc, Impl>(
argument_type, argument_types));
WhichDataType which(argument_type);
WhichDataType which(*remove_nullable(argument_types[1]));
AggregateFunctionPtr res(
creator_with_type_base<true, true, false, 1>::create<AggFunctionOrthBitmapFunc,
Impl>(result_is_nullable,
argument_types));
if (res) {
return res;
} else if (which.is_string_or_fixed_string()) {
return std::make_shared<AggFunctionOrthBitmapFunc<Impl<std::string_view>>>(
argument_types);
res.reset(
creator_without_type::create<AggFunctionOrthBitmapFunc<Impl<std::string_view>>>(
result_is_nullable, argument_types));
return res;
}
const IDataType& argument_type = *argument_types[1];
LOG(WARNING) << "Incorrect Type " << argument_type.get_name()
<< " of arguments for aggregate function " << name;
return nullptr;
@ -91,5 +95,16 @@ void register_aggregate_function_orthogonal_bitmap(AggregateFunctionSimpleFactor
create_aggregate_function_orthogonal_bitmap_union_count);
factory.register_function("intersect_count", create_aggregate_function_intersect_count);
factory.register_function("orthogonal_bitmap_intersect",
create_aggregate_function_orthogonal_bitmap_intersect, true);
factory.register_function("orthogonal_bitmap_intersect_count",
create_aggregate_function_orthogonal_bitmap_intersect_count, true);
factory.register_function("orthogonal_bitmap_union_count",
create_aggregate_function_orthogonal_bitmap_union_count, true);
factory.register_function("intersect_count", create_aggregate_function_intersect_count, true);
}
} // namespace doris::vectorized

View File

@ -92,12 +92,7 @@ public:
AggregateFunctionProduct(const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionProduct<T, TResult, Data>>(
argument_types_),
scale(0) {}
AggregateFunctionProduct(const IDataType& data_type, const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionProduct<T, TResult, Data>>(
argument_types_),
scale(get_decimal_scale(data_type)) {}
scale(get_decimal_scale(*argument_types_[0])) {}
DataTypePtr get_return_type() const override {
if constexpr (IsDecimalNumber<T>) {

View File

@ -45,25 +45,8 @@ template <template <typename> class Function>
AggregateFunctionPtr create_aggregate_function_sum(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
if (data_type->is_nullable()) {
auto no_null_argument_types = remove_nullable(argument_types);
if (is_decimal(no_null_argument_types[0])) {
res.reset(create_with_decimal_type_null<Function>(
no_null_argument_types, *no_null_argument_types[0], no_null_argument_types));
} else {
res.reset(create_with_numeric_type_null<Function>(no_null_argument_types,
no_null_argument_types));
}
} else {
if (is_decimal(data_type)) {
res.reset(create_with_decimal_type<Function>(*data_type, *data_type, argument_types));
} else {
res.reset(create_with_numeric_type<Function>(*data_type, argument_types));
}
}
AggregateFunctionPtr res(
creator_with_type::create<Function>(result_is_nullable, argument_types));
if (!res) {
LOG(WARNING) << fmt::format("Illegal type {} of argument for aggregate function {}",
argument_types[0]->get_name(), name);

View File

@ -59,12 +59,7 @@ public:
AggregateFunctionSum(const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data>>(
argument_types_),
scale(0) {}
AggregateFunctionSum(const IDataType& data_type, const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data>>(
argument_types_),
scale(get_decimal_scale(data_type)) {}
scale(get_decimal_scale(*argument_types_[0])) {}
DataTypePtr get_return_type() const override {
if constexpr (IsDecimalNumber<T>) {

View File

@ -28,7 +28,7 @@
namespace doris::vectorized {
template <template <typename> class Data, typename DataForVariadic>
template <template <typename> class Data>
AggregateFunctionPtr create_aggregate_function_uniq(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
@ -38,27 +38,29 @@ AggregateFunctionPtr create_aggregate_function_uniq(const std::string& name,
}
if (argument_types.size() == 1) {
const IDataType& argument_type = *argument_types[0];
AggregateFunctionPtr res(create_with_numeric_type<AggregateFunctionUniq, Data>(
*argument_types[0], argument_types));
const IDataType& argument_type = *remove_nullable(argument_types[0]);
WhichDataType which(argument_type);
// TODO: DateType
AggregateFunctionPtr res(creator_with_numeric_type::create<AggregateFunctionUniq, Data>(
result_is_nullable, argument_types));
if (res) {
return res;
} else if (which.is_decimal32()) {
return std::make_shared<AggregateFunctionUniq<Decimal32, Data<Int32>>>(argument_types);
return AggregateFunctionPtr(
creator_without_type::create<AggregateFunctionUniq<Decimal32, Data<Int32>>>(
result_is_nullable, argument_types));
} else if (which.is_decimal64()) {
return std::make_shared<AggregateFunctionUniq<Decimal64, Data<Int64>>>(argument_types);
} else if (which.is_decimal128()) {
return std::make_shared<AggregateFunctionUniq<Decimal128, Data<Int128>>>(
argument_types);
} else if (which.is_decimal128i()) {
return std::make_shared<AggregateFunctionUniq<Decimal128, Data<Int128>>>(
argument_types);
return AggregateFunctionPtr(
creator_without_type::create<AggregateFunctionUniq<Decimal64, Data<Int64>>>(
result_is_nullable, argument_types));
} else if (which.is_decimal128() || which.is_decimal128i()) {
return AggregateFunctionPtr(
creator_without_type::create<AggregateFunctionUniq<Decimal128, Data<Int128>>>(
result_is_nullable, argument_types));
} else if (which.is_string_or_fixed_string()) {
return std::make_shared<AggregateFunctionUniq<String, Data<String>>>(argument_types);
return AggregateFunctionPtr(
creator_without_type::create<AggregateFunctionUniq<String, Data<String>>>(
result_is_nullable, argument_types));
}
}
@ -67,9 +69,9 @@ AggregateFunctionPtr create_aggregate_function_uniq(const std::string& name,
void register_aggregate_function_uniq(AggregateFunctionSimpleFactory& factory) {
AggregateFunctionCreator creator =
create_aggregate_function_uniq<AggregateFunctionUniqExactData,
AggregateFunctionUniqExactData<String>>;
create_aggregate_function_uniq<AggregateFunctionUniqExactData>;
factory.register_function("multi_distinct_count", creator);
factory.register_function("multi_distinct_count", creator, true);
}
} // namespace doris::vectorized

View File

@ -23,9 +23,10 @@
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/aggregate_function_null.h"
#include "vec/data_types/data_type.h"
#include "vec/utils/template_helpers.hpp"
// TODO: Should we support decimal in numeric types?
#define FOR_NUMERIC_TYPES(M) \
#define FOR_INTEGER_TYPES(M) \
M(UInt8) \
M(UInt16) \
M(UInt32) \
@ -34,10 +35,16 @@
M(Int16) \
M(Int32) \
M(Int64) \
M(Int128) \
M(Float32) \
M(Int128)
#define FOR_FLOAT_TYPES(M) \
M(Float32) \
M(Float64)
#define FOR_NUMERIC_TYPES(M) \
FOR_INTEGER_TYPES(M) \
FOR_FLOAT_TYPES(M)
#define FOR_DECIMAL_TYPES(M) \
M(Decimal32) \
M(Decimal64) \
@ -48,150 +55,162 @@ namespace doris::vectorized {
/** Create an aggregate function with a numeric type in the template parameter, depending on the type of the argument.
*/
template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction* create_with_numeric_type(const IDataType& argument_type,
TArgs&&... args) {
WhichDataType which(argument_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
return nullptr;
}
template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction* create_with_numeric_type_null(const DataTypes& argument_types,
TArgs&&... args) {
WhichDataType which(argument_types[0]);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionNullUnaryInline<AggregateFunctionTemplate<TYPE>, true>( \
new AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...), \
argument_types);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
return nullptr;
}
template <template <typename, bool> class AggregateFunctionTemplate, bool bool_param,
typename... TArgs>
static IAggregateFunction* create_with_numeric_type(const IDataType& argument_type,
TArgs&&... args) {
WhichDataType which(argument_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionTemplate<TYPE, bool_param>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
return nullptr;
}
template <template <typename, typename> class AggregateFunctionTemplate, typename Data,
typename... TArgs>
static IAggregateFunction* create_with_numeric_type(const IDataType& argument_type,
TArgs&&... args) {
WhichDataType which(argument_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionTemplate<TYPE, Data>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
return nullptr;
}
template <template <typename, typename> class AggregateFunctionTemplate,
template <typename> class Data, typename... TArgs>
static IAggregateFunction* create_with_numeric_type(const IDataType& argument_type,
TArgs&&... args) {
WhichDataType which(argument_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionTemplate<TYPE, Data<TYPE>>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
return nullptr;
}
/// Builder whose result type T is AggregateFunctionTemplate instantiated directly with Type.
template <template <typename> class AggregateFunctionTemplate, typename Type>
struct BuilderDirect {
using T = AggregateFunctionTemplate<Type>;
};
template <template <typename> class AggregateFunctionTemplate, template <typename> class Data,
typename... TArgs>
static IAggregateFunction* create_with_numeric_type(const IDataType& argument_type,
TArgs&&... args) {
WhichDataType which(argument_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionTemplate<Data<TYPE>>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
return nullptr;
}
typename Type>
struct BuilderData {
using T = AggregateFunctionTemplate<Data<Type>>;
};
/// Builder whose result type T is AggregateFunctionTemplate instantiated with
/// Data<Impl<Type>>, i.e. Type is wrapped by Impl first and then by Data.
template <template <typename> class AggregateFunctionTemplate, template <typename> class Data,
template <typename> class Impl, typename Type>
struct BuilderDataImpl {
using T = AggregateFunctionTemplate<Data<Impl<Type>>>;
};
/// Builder for two-parameter function templates: the result type T receives Type itself as
/// the first argument and Data<Type> as the second.
template <template <typename, typename> class AggregateFunctionTemplate,
template <typename> class Data, typename Type>
struct BuilderDirectAndData {
using T = AggregateFunctionTemplate<Type, Data<Type>>;
};
template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction* create_with_decimal_type(const IDataType& argument_type,
TArgs&&... args) {
WhichDataType which(argument_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...);
FOR_DECIMAL_TYPES(DISPATCH)
#undef DISPATCH
return nullptr;
}
/// Binds AggregateFunctionTemplate so that the nested Builder<Type> needs only the value
/// type; Builder<Type> is BuilderDirect<AggregateFunctionTemplate, Type>.
template <template <typename> class AggregateFunctionTemplate>
struct CurryDirect {
template <typename Type>
using Builder = BuilderDirect<AggregateFunctionTemplate, Type>;
};
/// Binds AggregateFunctionTemplate and Data so that the nested Builder<Type> needs only the
/// value type; Builder<Type> is BuilderData<AggregateFunctionTemplate, Data, Type>.
template <template <typename> class AggregateFunctionTemplate, template <typename> class Data>
struct CurryData {
template <typename Type>
using Builder = BuilderData<AggregateFunctionTemplate, Data, Type>;
};
/// Binds AggregateFunctionTemplate, Data and Impl so that the nested Builder<Type> needs only
/// the value type; Builder<Type> is BuilderDataImpl<AggregateFunctionTemplate, Data, Impl, Type>.
template <template <typename> class AggregateFunctionTemplate, template <typename> class Data,
template <typename> class Impl>
struct CurryDataImpl {
template <typename Type>
using Builder = BuilderDataImpl<AggregateFunctionTemplate, Data, Impl, Type>;
};
/// Binds a two-parameter AggregateFunctionTemplate and Data so that the nested Builder<Type>
/// needs only the value type; Builder<Type> is
/// BuilderDirectAndData<AggregateFunctionTemplate, Data, Type>.
template <template <typename, typename> class AggregateFunctionTemplate,
template <typename> class Data>
struct CurryDirectAndData {
template <typename Type>
using Builder = BuilderDirectAndData<AggregateFunctionTemplate, Data, Type>;
};
template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction* create_with_decimal_type_null(const DataTypes& argument_types,
TArgs&&... args) {
WhichDataType which(argument_types[0]);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionNullUnaryInline<AggregateFunctionTemplate<TYPE>, true>( \
new AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...), \
argument_types);
FOR_DECIMAL_TYPES(DISPATCH)
#undef DISPATCH
return nullptr;
}
template <bool allow_integer, bool allow_float, bool allow_decimal, int define_index = 0>
struct creator_with_type_base {
template <typename Class, typename... TArgs>
static IAggregateFunction* create_base(const bool result_is_nullable,
const DataTypes& argument_types, TArgs&&... args) {
WhichDataType which(remove_nullable(argument_types[define_index]));
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) { \
using T = typename Class::template Builder<TYPE>::T; \
if (have_nullable(argument_types)) { \
IAggregateFunction* result = nullptr; \
if (argument_types.size() > 1) { \
std::visit( \
[&](auto result_is_nullable) { \
result = new AggregateFunctionNullVariadicInline<T, \
result_is_nullable>( \
new T(std::forward<TArgs>(args)..., \
remove_nullable(argument_types)), \
argument_types); \
}, \
make_bool_variant(result_is_nullable)); \
} else { \
std::visit( \
[&](auto result_is_nullable) { \
result = new AggregateFunctionNullUnaryInline<T, result_is_nullable>( \
new T(std::forward<TArgs>(args)..., \
remove_nullable(argument_types)), \
argument_types); \
}, \
make_bool_variant(result_is_nullable)); \
} \
return result; \
} else { \
return new T(std::forward<TArgs>(args)..., argument_types); \
} \
}
template <template <typename, typename> class AggregateFunctionTemplate, typename Data,
typename... TArgs>
static IAggregateFunction* create_with_decimal_type(const IDataType& argument_type,
TArgs&&... args) {
WhichDataType which(argument_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionTemplate<TYPE, Data>(std::forward<TArgs>(args)...);
FOR_DECIMAL_TYPES(DISPATCH)
if constexpr (allow_integer) {
FOR_INTEGER_TYPES(DISPATCH);
}
if constexpr (allow_float) {
FOR_FLOAT_TYPES(DISPATCH);
}
if constexpr (allow_decimal) {
FOR_DECIMAL_TYPES(DISPATCH);
}
#undef DISPATCH
return nullptr;
}
return nullptr;
}
/** For template with two arguments.
*/
template <typename FirstType, template <typename, typename> class AggregateFunctionTemplate,
typename... TArgs>
static IAggregateFunction* create_with_two_numeric_types_second(const IDataType& second_type,
TArgs&&... args) {
WhichDataType which(second_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionTemplate<FirstType, TYPE>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
return nullptr;
}
template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction* create(TArgs&&... args) {
return create_base<CurryDirect<AggregateFunctionTemplate>>(std::forward<TArgs>(args)...);
}
template <template <typename, typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction* create_with_two_numeric_types(const IDataType& first_type,
const IDataType& second_type,
TArgs&&... args) {
WhichDataType which(first_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return create_with_two_numeric_types_second<TYPE, AggregateFunctionTemplate>( \
second_type, std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
return nullptr;
}
/// Entry point for an aggregate template paired with a per-type Data template:
/// combines both through CurryData and forwards all arguments to create_base.
template <template <typename> class AggregateFunctionTemplate, template <typename> class Data,
          typename... TArgs>
static IAggregateFunction* create(TArgs&&... args) {
    using Curried = CurryData<AggregateFunctionTemplate, Data>;
    return create_base<Curried>(std::forward<TArgs>(args)...);
}
/// Entry point for an aggregate template combined with Data and Impl
/// templates: bundles the three through CurryDataImpl and forwards all
/// arguments to create_base.
template <template <typename> class AggregateFunctionTemplate, template <typename> class Data,
          template <typename> class Impl, typename... TArgs>
static IAggregateFunction* create(TArgs&&... args) {
    using Curried = CurryDataImpl<AggregateFunctionTemplate, Data, Impl>;
    return create_base<Curried>(std::forward<TArgs>(args)...);
}
/// Entry point for a two-parameter aggregate template paired with a per-type
/// Data template: bundles them through CurryDirectAndData and forwards all
/// arguments to create_base.
template <template <typename, typename> class AggregateFunctionTemplate,
          template <typename> class Data, typename... TArgs>
static IAggregateFunction* create(TArgs&&... args) {
    using Curried = CurryDirectAndData<AggregateFunctionTemplate, Data>;
    return create_base<Curried>(std::forward<TArgs>(args)...);
}
};
// Ready-made instantiations of creator_with_type_base.
// NOTE(review): the three boolean flags presumably correspond, in order, to
// (allow_integer, allow_float, allow_decimal) -- confirm against the
// creator_with_type_base template parameter list.
using creator_with_integer_type = creator_with_type_base<true, false, false>;
using creator_with_numeric_type = creator_with_type_base<true, true, false>;
using creator_with_decimal_type = creator_with_type_base<false, false, true>;
// All three flags enabled.
using creator_with_type = creator_with_type_base<true, true, true>;
/// Factory for aggregate functions whose concrete class is already known
/// (no dispatch on the argument's value type is performed here).
struct creator_without_type {
    /** Instantiates AggregateFunctionTemplate and, if needed, a null-handling wrapper.
      *
      * - No nullable argument type: returns a bare AggregateFunctionTemplate
      *   constructed from (args..., argument_types).
      * - At least one nullable argument type: constructs the nested function
      *   from (args..., remove_nullable(argument_types)) and wraps it in
      *   AggregateFunctionNullVariadicInline when there is more than one
      *   argument type, otherwise in AggregateFunctionNullUnaryInline. The
      *   wrapper receives the original `argument_types`.
      *
      * `result_is_nullable` is lifted to a compile-time constant through
      * make_bool_variant + std::visit so it can be used as a template argument.
      */
    template <typename AggregateFunctionTemplate, typename... TArgs>
    static IAggregateFunction* create(const bool result_is_nullable,
                                      const DataTypes& argument_types, TArgs&&... args) {
        // Fast path: nothing is nullable, so no wrapper is required.
        if (!have_nullable(argument_types)) {
            return new AggregateFunctionTemplate(std::forward<TArgs>(args)..., argument_types);
        }

        const bool is_variadic = argument_types.size() > 1;
        IAggregateFunction* wrapped = nullptr;
        std::visit(
                [&](auto nullable_result) {
                    // The nested function only ever sees the non-nullable types.
                    auto* nested = new AggregateFunctionTemplate(
                            std::forward<TArgs>(args)..., remove_nullable(argument_types));
                    if (is_variadic) {
                        wrapped = new AggregateFunctionNullVariadicInline<
                                AggregateFunctionTemplate, nullable_result>(nested,
                                                                            argument_types);
                    } else {
                        wrapped = new AggregateFunctionNullUnaryInline<
                                AggregateFunctionTemplate, nullable_result>(nested,
                                                                            argument_types);
                    }
                },
                make_bool_variant(result_is_nullable));
        return wrapped;
    }
};
} // namespace doris::vectorized

View File

@ -255,8 +255,7 @@ inline const DataTypeDecimal<T>* check_decimal(const IDataType& data_type) {
return typeid_cast<const DataTypeDecimal<T>*>(&data_type);
}
inline UInt32 get_decimal_scale(const IDataType& data_type,
UInt32 default_value = std::numeric_limits<UInt32>::max()) {
inline UInt32 get_decimal_scale(const IDataType& data_type, UInt32 default_value = 0) {
if (auto* decimal_type = check_decimal<Decimal32>(data_type)) return decimal_type->get_scale();
if (auto* decimal_type = check_decimal<Decimal64>(data_type)) return decimal_type->get_scale();
if (auto* decimal_type = check_decimal<Decimal128>(data_type)) return decimal_type->get_scale();

View File

@ -161,13 +161,18 @@ DataTypePtr remove_nullable(const DataTypePtr& type) {
/// Returns a new list holding, for every element of `types` and in the same
/// order, the result of the single-type remove_nullable overload.
///
/// Each input contributes exactly one output element, so the result has the
/// same length as `types` and positions line up with the input.
DataTypes remove_nullable(const DataTypes& types) {
    DataTypes no_null_types;
    no_null_types.reserve(types.size());
    for (const auto& type : types) {
        // NOTE(review): relies on the single-type overload returning a
        // non-nullable input unchanged and the nested type for a nullable
        // one -- confirm against its definition. Doing the nullable check by
        // hand here as well would append every element twice.
        no_null_types.push_back(remove_nullable(type));
    }
    return no_null_types;
}
bool have_nullable(const DataTypes& types) {
for (auto& type : types) {
if (type->is_nullable()) {
return true;
}
}
return false;
}
} // namespace doris::vectorized

View File

@ -49,6 +49,10 @@ public:
bool equals(const IDataType& rhs) const override;
// Delegates the answer to the wrapped (nested) data type.
bool is_value_unambiguously_represented_in_contiguous_memory_region() const override {
return nested_data_type->is_value_unambiguously_represented_in_contiguous_memory_region();
}
// Unconditionally reports true for this type.
bool get_is_parametric() const override { return true; }
// Unconditionally reports true for this type.
bool have_subtypes() const override { return true; }
bool cannot_be_stored_in_tables() const override {
@ -94,5 +98,6 @@ private:
// NOTE(review): definition not visible from here; presumably returns `type`
// wrapped in a nullable data type -- confirm.
DataTypePtr make_nullable(const DataTypePtr& type);
// NOTE(review): definition not visible from here; presumably returns the
// nested type of a nullable `type` and a non-nullable `type` unchanged -- confirm.
DataTypePtr remove_nullable(const DataTypePtr& type);
// List variant: builds and returns a new list derived from `types` with the
// nullable wrapper removed from its elements.
DataTypes remove_nullable(const DataTypes& types);
// Returns true if at least one element of `types` reports is_nullable().
bool have_nullable(const DataTypes& types);
} // namespace doris::vectorized

View File

@ -335,7 +335,7 @@ Status get_least_supertype(const DataTypes& types, DataTypePtr* type, bool compa
UInt32 max_scale = 0;
for (const auto& type : types) {
UInt32 scale = get_decimal_scale(*type, 0);
UInt32 scale = get_decimal_scale(*type);
if (scale > max_scale) max_scale = scale;
}

View File

@ -118,15 +118,8 @@ struct AggregateFunction {
static auto create(const DataTypePtr& data_type_ptr) -> AggregateFunctionPtr {
DataTypes data_types = {remove_nullable(data_type_ptr)};
auto& data_type = *data_types.front();
AggregateFunctionPtr nested_function;
if (is_decimal(data_types.front())) {
nested_function = AggregateFunctionPtr(
create_with_decimal_type<Function>(data_type, data_type, data_types));
} else {
nested_function =
AggregateFunctionPtr(create_with_numeric_type<Function>(data_type, data_types));
}
nested_function.reset(creator_with_type::create<Function>(false, data_types));
AggregateFunctionPtr function;
function.reset(new AggregateFunctionNullUnary<true>(nested_function,