[Opt](exec) opt aggreate function performance in nullable column
This commit is contained in:
@ -45,11 +45,23 @@ AggregateFunctionPtr create_aggregate_function_avg(const std::string& name,
|
||||
|
||||
AggregateFunctionPtr res;
|
||||
DataTypePtr data_type = argument_types[0];
|
||||
if (is_decimal(data_type)) {
|
||||
res.reset(
|
||||
create_with_decimal_type<AggregateFuncAvg>(*data_type, *data_type, argument_types));
|
||||
if (data_type->is_nullable()) {
|
||||
auto no_null_argument_types = remove_nullable(argument_types);
|
||||
if (is_decimal(no_null_argument_types[0])) {
|
||||
res.reset(create_with_decimal_type_null<AggregateFuncAvg>(
|
||||
no_null_argument_types, parameters, *no_null_argument_types[0],
|
||||
no_null_argument_types));
|
||||
} else {
|
||||
res.reset(create_with_numeric_type_null<AggregateFuncAvg>(
|
||||
no_null_argument_types, parameters, no_null_argument_types));
|
||||
}
|
||||
} else {
|
||||
res.reset(create_with_numeric_type<AggregateFuncAvg>(*data_type, argument_types));
|
||||
if (is_decimal(data_type)) {
|
||||
res.reset(create_with_decimal_type<AggregateFuncAvg>(*data_type, *data_type,
|
||||
argument_types));
|
||||
} else {
|
||||
res.reset(create_with_numeric_type<AggregateFuncAvg>(*data_type, argument_types));
|
||||
}
|
||||
}
|
||||
|
||||
if (!res) {
|
||||
@ -61,5 +73,6 @@ AggregateFunctionPtr create_aggregate_function_avg(const std::string& name,
|
||||
|
||||
void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) {
|
||||
factory.register_function("avg", create_aggregate_function_avg);
|
||||
factory.register_function("avg", create_aggregate_function_avg, true);
|
||||
}
|
||||
} // namespace doris::vectorized
|
||||
|
||||
@ -121,7 +121,8 @@ public:
|
||||
DataTypePtr get_serialized_type() const override { return std::make_shared<DataTypeUInt64>(); }
|
||||
};
|
||||
|
||||
/// Simply count number of not-NULL values.
|
||||
// TODO: Maybe AggregateFunctionCountNotNullUnary should be a subclass of AggregateFunctionCount
|
||||
// Simply count number of not-NULL values.
|
||||
class AggregateFunctionCountNotNullUnary final
|
||||
: public IAggregateFunctionDataHelper<AggregateFunctionCountData,
|
||||
AggregateFunctionCountNotNullUnary> {
|
||||
|
||||
@ -25,7 +25,6 @@
|
||||
#include "vec/aggregate_functions/helpers.h"
|
||||
|
||||
namespace doris::vectorized {
|
||||
|
||||
/// min, max, any
|
||||
template <template <typename, bool> class AggregateFunctionTemplate, template <typename> class Data>
|
||||
static IAggregateFunction* create_aggregate_function_single_value(const String& name,
|
||||
|
||||
@ -85,7 +85,6 @@ public:
|
||||
};
|
||||
|
||||
void register_aggregate_function_combinator_null(AggregateFunctionSimpleFactory& factory) {
|
||||
// factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorNull>());
|
||||
AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types,
|
||||
const Array& params, const bool result_is_nullable) {
|
||||
auto function_combinator = std::make_shared<AggregateFunctionCombinatorNull>();
|
||||
|
||||
@ -40,6 +40,7 @@ namespace doris::vectorized {
|
||||
/// If all rows had NULL, the behaviour is determined by "result_is_nullable" template parameter.
|
||||
/// true - return NULL; false - return value from empty aggregation state of nested function.
|
||||
|
||||
// TODO: only keep class xxxInline after we support all aggregate function
|
||||
template <bool result_is_nullable, typename Derived>
|
||||
class AggregateFunctionNullBase : public IAggregateFunctionHelper<Derived> {
|
||||
protected:
|
||||
@ -409,4 +410,248 @@ private:
|
||||
is_nullable; /// Plain array is better than std::vector due to one indirection less.
|
||||
};
|
||||
|
||||
template <typename NestFunction, bool result_is_nullable, typename Derived>
|
||||
class AggregateFunctionNullBaseInline : public IAggregateFunctionHelper<Derived> {
|
||||
protected:
|
||||
std::unique_ptr<NestFunction> nested_function;
|
||||
size_t prefix_size;
|
||||
|
||||
/** In addition to data for nested aggregate function, we keep a flag
|
||||
* indicating - was there at least one non-NULL value accumulated.
|
||||
* In case of no not-NULL values, the function will return NULL.
|
||||
*
|
||||
* We use prefix_size bytes for flag to satisfy the alignment requirement of nested state.
|
||||
*/
|
||||
|
||||
AggregateDataPtr nested_place(AggregateDataPtr __restrict place) const noexcept {
|
||||
return place + prefix_size;
|
||||
}
|
||||
|
||||
ConstAggregateDataPtr nested_place(ConstAggregateDataPtr __restrict place) const noexcept {
|
||||
return place + prefix_size;
|
||||
}
|
||||
|
||||
static void init_flag(AggregateDataPtr __restrict place) noexcept {
|
||||
if constexpr (result_is_nullable) {
|
||||
place[0] = false;
|
||||
}
|
||||
}
|
||||
|
||||
static void set_flag(AggregateDataPtr __restrict place) noexcept {
|
||||
if constexpr (result_is_nullable) {
|
||||
place[0] = true;
|
||||
}
|
||||
}
|
||||
|
||||
static bool get_flag(ConstAggregateDataPtr __restrict place) noexcept {
|
||||
return result_is_nullable ? place[0] : true;
|
||||
}
|
||||
|
||||
public:
|
||||
AggregateFunctionNullBaseInline(IAggregateFunction* nested_function_,
|
||||
const DataTypes& arguments, const Array& params)
|
||||
: IAggregateFunctionHelper<Derived>(arguments, params),
|
||||
nested_function {assert_cast<NestFunction*>(nested_function_)} {
|
||||
if (result_is_nullable) {
|
||||
prefix_size = nested_function->align_of_data();
|
||||
} else {
|
||||
prefix_size = 0;
|
||||
}
|
||||
}
|
||||
|
||||
String get_name() const override {
|
||||
/// This is just a wrapper. The function for Nullable arguments is named the same as the nested function itself.
|
||||
return nested_function->get_name();
|
||||
}
|
||||
|
||||
DataTypePtr get_return_type() const override {
|
||||
return result_is_nullable ? make_nullable(nested_function->get_return_type())
|
||||
: nested_function->get_return_type();
|
||||
}
|
||||
|
||||
void create(AggregateDataPtr __restrict place) const override {
|
||||
init_flag(place);
|
||||
nested_function->create(nested_place(place));
|
||||
}
|
||||
|
||||
void destroy(AggregateDataPtr __restrict place) const noexcept override {
|
||||
nested_function->destroy(nested_place(place));
|
||||
}
|
||||
void reset(AggregateDataPtr place) const override {
|
||||
init_flag(place);
|
||||
nested_function->reset(nested_place(place));
|
||||
}
|
||||
|
||||
bool has_trivial_destructor() const override {
|
||||
return nested_function->has_trivial_destructor();
|
||||
}
|
||||
|
||||
size_t size_of_data() const override { return prefix_size + nested_function->size_of_data(); }
|
||||
|
||||
size_t align_of_data() const override { return nested_function->align_of_data(); }
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
|
||||
Arena* arena) const override {
|
||||
if (result_is_nullable && get_flag(rhs)) {
|
||||
set_flag(place);
|
||||
}
|
||||
|
||||
nested_function->merge(nested_place(place), nested_place(rhs), arena);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
|
||||
bool flag = get_flag(place);
|
||||
if (result_is_nullable) {
|
||||
write_binary(flag, buf);
|
||||
}
|
||||
if (flag) {
|
||||
nested_function->serialize(nested_place(place), buf);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
|
||||
Arena* arena) const override {
|
||||
bool flag = true;
|
||||
if (result_is_nullable) {
|
||||
read_binary(flag, buf);
|
||||
}
|
||||
if (flag) {
|
||||
set_flag(place);
|
||||
nested_function->deserialize(nested_place(place), buf, arena);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize_and_merge(AggregateDataPtr __restrict place, BufferReadable& buf,
|
||||
Arena* arena) const override {
|
||||
bool flag = true;
|
||||
if (result_is_nullable) {
|
||||
read_binary(flag, buf);
|
||||
}
|
||||
if (flag) {
|
||||
set_flag(place);
|
||||
nested_function->deserialize_and_merge(nested_place(place), buf, arena);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column,
|
||||
Arena* arena) const override {
|
||||
size_t num_rows = column.size();
|
||||
for (size_t i = 0; i != num_rows; ++i) {
|
||||
VectorBufferReader buffer_reader(
|
||||
(assert_cast<const ColumnString&>(column)).get_data_at(i));
|
||||
deserialize_and_merge(place, buffer_reader, arena);
|
||||
}
|
||||
}
|
||||
|
||||
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
|
||||
if constexpr (result_is_nullable) {
|
||||
ColumnNullable& to_concrete = assert_cast<ColumnNullable&>(to);
|
||||
if (get_flag(place)) {
|
||||
nested_function->insert_result_into(nested_place(place),
|
||||
to_concrete.get_nested_column());
|
||||
to_concrete.get_null_map_data().push_back(0);
|
||||
} else {
|
||||
to_concrete.insert_default();
|
||||
}
|
||||
} else {
|
||||
nested_function->insert_result_into(nested_place(place), to);
|
||||
}
|
||||
}
|
||||
|
||||
bool allocates_memory_in_arena() const override {
|
||||
return nested_function->allocates_memory_in_arena();
|
||||
}
|
||||
|
||||
bool is_state() const override { return nested_function->is_state(); }
|
||||
};
|
||||
|
||||
/** There are two cases: for single argument and variadic.
|
||||
* Code for single argument is much more efficient.
|
||||
*/
|
||||
template <typename NestFuction, bool result_is_nullable>
|
||||
class AggregateFunctionNullUnaryInline final
|
||||
: public AggregateFunctionNullBaseInline<
|
||||
NestFuction, result_is_nullable,
|
||||
AggregateFunctionNullUnaryInline<NestFuction, result_is_nullable>> {
|
||||
public:
|
||||
AggregateFunctionNullUnaryInline(IAggregateFunction* nested_function_,
|
||||
const DataTypes& arguments, const Array& params)
|
||||
: AggregateFunctionNullBaseInline<
|
||||
NestFuction, result_is_nullable,
|
||||
AggregateFunctionNullUnaryInline<NestFuction, result_is_nullable>>(
|
||||
nested_function_, arguments, params) {}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
|
||||
Arena* arena) const override {
|
||||
const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
|
||||
if (!column->is_null_at(row_num)) {
|
||||
this->set_flag(place);
|
||||
const IColumn* nested_column = &column->get_nested_column();
|
||||
this->nested_function->add(this->nested_place(place), &nested_column, row_num, arena);
|
||||
}
|
||||
}
|
||||
|
||||
void add_not_nullable(AggregateDataPtr __restrict place, const IColumn** columns,
|
||||
size_t row_num, Arena* arena) const {
|
||||
const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
|
||||
this->set_flag(place);
|
||||
const IColumn* nested_column = &column->get_nested_column();
|
||||
this->nested_function->add(this->nested_place(place), &nested_column, row_num, arena);
|
||||
}
|
||||
|
||||
void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
|
||||
const IColumn** columns, Arena* arena, bool agg_many) const override {
|
||||
const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
|
||||
// The overhead introduced is negligible here, just an extra memory read from NullMap
|
||||
const auto* __restrict null_map_data = column->get_null_map_data().data();
|
||||
const IColumn* nested_column = &column->get_nested_column();
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
if (!null_map_data[i]) {
|
||||
AggregateDataPtr __restrict place = places[i] + place_offset;
|
||||
this->set_flag(place);
|
||||
this->nested_function->add(this->nested_place(place), &nested_column, i, arena);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
|
||||
Arena* arena) const override {
|
||||
const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
|
||||
bool has_null = column->has_null();
|
||||
|
||||
if (has_null) {
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
if (!column->is_null_at(i)) {
|
||||
this->set_flag(place);
|
||||
this->add(place, columns, i, arena);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
this->set_flag(place);
|
||||
const IColumn* nested_column = &column->get_nested_column();
|
||||
this->nested_function->add_batch_single_place(batch_size, this->nested_place(place),
|
||||
&nested_column, arena);
|
||||
}
|
||||
}
|
||||
|
||||
void add_batch_range(size_t batch_begin, size_t batch_end, AggregateDataPtr place,
|
||||
const IColumn** columns, Arena* arena, bool has_null) override {
|
||||
const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
|
||||
|
||||
if (has_null) {
|
||||
for (size_t i = batch_begin; i <= batch_end; ++i) {
|
||||
if (!column->is_null_at(i)) {
|
||||
this->set_flag(place);
|
||||
this->add(place, columns, i, arena);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
this->set_flag(place);
|
||||
const IColumn* nested_column = &column->get_nested_column();
|
||||
this->nested_function->add_batch_range(batch_begin, batch_end,
|
||||
this->nested_place(place), &nested_column, arena,
|
||||
false);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace doris::vectorized
|
||||
|
||||
@ -25,6 +25,7 @@
|
||||
#include "common/logging.h"
|
||||
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
|
||||
#include "vec/aggregate_functions/helpers.h"
|
||||
#include "vec/data_types/data_type_nullable.h"
|
||||
|
||||
namespace doris::vectorized {
|
||||
|
||||
@ -45,15 +46,24 @@ AggregateFunctionPtr create_aggregate_function_sum(const std::string& name,
|
||||
const DataTypes& argument_types,
|
||||
const Array& parameters,
|
||||
const bool result_is_nullable) {
|
||||
// assert_no_parameters(name, parameters);
|
||||
// assert_unary(name, argument_types);
|
||||
|
||||
AggregateFunctionPtr res;
|
||||
DataTypePtr data_type = argument_types[0];
|
||||
if (is_decimal(data_type)) {
|
||||
res.reset(create_with_decimal_type<Function>(*data_type, *data_type, argument_types));
|
||||
if (data_type->is_nullable()) {
|
||||
auto no_null_argument_types = remove_nullable(argument_types);
|
||||
if (is_decimal(no_null_argument_types[0])) {
|
||||
res.reset(create_with_decimal_type_null<Function>(no_null_argument_types, parameters,
|
||||
*no_null_argument_types[0],
|
||||
no_null_argument_types));
|
||||
} else {
|
||||
res.reset(create_with_numeric_type_null<Function>(no_null_argument_types, parameters,
|
||||
no_null_argument_types));
|
||||
}
|
||||
} else {
|
||||
res.reset(create_with_numeric_type<Function>(*data_type, argument_types));
|
||||
if (is_decimal(data_type)) {
|
||||
res.reset(create_with_decimal_type<Function>(*data_type, *data_type, argument_types));
|
||||
} else {
|
||||
res.reset(create_with_numeric_type<Function>(*data_type, argument_types));
|
||||
}
|
||||
}
|
||||
|
||||
if (!res) {
|
||||
@ -84,6 +94,8 @@ AggregateFunctionPtr create_aggregate_function_sum_reader(const std::string& nam
|
||||
|
||||
void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory) {
|
||||
factory.register_function("sum", create_aggregate_function_sum<AggregateFunctionSumSimple>);
|
||||
factory.register_function("sum", create_aggregate_function_sum<AggregateFunctionSumSimple>,
|
||||
true);
|
||||
}
|
||||
|
||||
} // namespace doris::vectorized
|
||||
|
||||
@ -21,8 +21,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "vec/aggregate_functions/aggregate_function.h"
|
||||
#include "vec/aggregate_functions/aggregate_function_null.h"
|
||||
#include "vec/data_types/data_type.h"
|
||||
|
||||
// TODO: Should we support decimal in numeric types?
|
||||
#define FOR_NUMERIC_TYPES(M) \
|
||||
M(UInt8) \
|
||||
M(UInt16) \
|
||||
@ -36,6 +38,12 @@
|
||||
M(Float32) \
|
||||
M(Float64)
|
||||
|
||||
#define FOR_DECIMAL_TYPES(M) \
|
||||
M(Decimal32) \
|
||||
M(Decimal64) \
|
||||
M(Decimal128) \
|
||||
M(Decimal128I)
|
||||
|
||||
namespace doris::vectorized {
|
||||
|
||||
/** Create an aggregate function with a numeric type in the template parameter, depending on the type of the argument.
|
||||
@ -49,12 +57,20 @@ static IAggregateFunction* create_with_numeric_type(const IDataType& argument_ty
|
||||
return new AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...);
|
||||
FOR_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
if (which.idx == TypeIndex::Enum8) {
|
||||
return new AggregateFunctionTemplate<Int8>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
if (which.idx == TypeIndex::Enum16) {
|
||||
return new AggregateFunctionTemplate<Int16>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
|
||||
static IAggregateFunction* create_with_numeric_type_null(const DataTypes& argument_types,
|
||||
const Array& params, TArgs&&... args) {
|
||||
WhichDataType which(argument_types[0]);
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) \
|
||||
return new AggregateFunctionNullUnaryInline<AggregateFunctionTemplate<TYPE>, true>( \
|
||||
new AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...), argument_types, \
|
||||
params);
|
||||
FOR_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -68,12 +84,6 @@ static IAggregateFunction* create_with_numeric_type(const IDataType& argument_ty
|
||||
return new AggregateFunctionTemplate<TYPE, bool_param>(std::forward<TArgs>(args)...);
|
||||
FOR_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
if (which.idx == TypeIndex::Enum8) {
|
||||
return new AggregateFunctionTemplate<Int8, bool_param>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
if (which.idx == TypeIndex::Enum16) {
|
||||
return new AggregateFunctionTemplate<Int16, bool_param>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -87,12 +97,6 @@ static IAggregateFunction* create_with_numeric_type(const IDataType& argument_ty
|
||||
return new AggregateFunctionTemplate<TYPE, Data>(std::forward<TArgs>(args)...);
|
||||
FOR_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
if (which.idx == TypeIndex::Enum8) {
|
||||
return new AggregateFunctionTemplate<Int8, Data>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
if (which.idx == TypeIndex::Enum16) {
|
||||
return new AggregateFunctionTemplate<Int16, Data>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -106,12 +110,6 @@ static IAggregateFunction* create_with_numeric_type(const IDataType& argument_ty
|
||||
return new AggregateFunctionTemplate<TYPE, Data<TYPE>>(std::forward<TArgs>(args)...);
|
||||
FOR_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
if (which.idx == TypeIndex::Enum8) {
|
||||
return new AggregateFunctionTemplate<Int8, Data<Int8>>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
if (which.idx == TypeIndex::Enum16) {
|
||||
return new AggregateFunctionTemplate<Int16, Data<Int16>>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -125,51 +123,6 @@ static IAggregateFunction* create_with_numeric_type(const IDataType& argument_ty
|
||||
return new AggregateFunctionTemplate<Data<TYPE>>(std::forward<TArgs>(args)...);
|
||||
FOR_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
// if (which.idx == TypeIndex::Enum8) return new AggregateFunctionTemplate<Data<Int8>>(std::forward<TArgs>(args)...);
|
||||
// if (which.idx == TypeIndex::Enum16) return new AggregateFunctionTemplate<Data<Int16>>(std::forward<TArgs>(args)...);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <template <typename, typename> class AggregateFunctionTemplate,
|
||||
template <typename> class Data, typename... TArgs>
|
||||
static IAggregateFunction* create_with_unsigned_integer_type(const IDataType& argument_type,
|
||||
TArgs&&... args) {
|
||||
WhichDataType which(argument_type);
|
||||
if (which.idx == TypeIndex::UInt8) {
|
||||
return new AggregateFunctionTemplate<UInt8, Data<UInt8>>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
if (which.idx == TypeIndex::UInt16) {
|
||||
return new AggregateFunctionTemplate<UInt16, Data<UInt16>>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
if (which.idx == TypeIndex::UInt32) {
|
||||
return new AggregateFunctionTemplate<UInt32, Data<UInt32>>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
if (which.idx == TypeIndex::UInt64) {
|
||||
return new AggregateFunctionTemplate<UInt64, Data<UInt64>>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
|
||||
static IAggregateFunction* create_with_numeric_based_type(const IDataType& argument_type,
|
||||
TArgs&&... args) {
|
||||
IAggregateFunction* f = create_with_numeric_type<AggregateFunctionTemplate>(
|
||||
argument_type, std::forward<TArgs>(args)...);
|
||||
if (f) {
|
||||
return f;
|
||||
}
|
||||
|
||||
/// expects that DataTypeDate based on UInt16, DataTypeDateTime based on UInt32 and UUID based on UInt128
|
||||
WhichDataType which(argument_type);
|
||||
if (which.idx == TypeIndex::Date) {
|
||||
return new AggregateFunctionTemplate<UInt16>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
if (which.idx == TypeIndex::DateTime) {
|
||||
return new AggregateFunctionTemplate<UInt32>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
if (which.idx == TypeIndex::UUID) {
|
||||
return new AggregateFunctionTemplate<UInt128>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -177,18 +130,25 @@ template <template <typename> class AggregateFunctionTemplate, typename... TArgs
|
||||
static IAggregateFunction* create_with_decimal_type(const IDataType& argument_type,
|
||||
TArgs&&... args) {
|
||||
WhichDataType which(argument_type);
|
||||
if (which.idx == TypeIndex::Decimal32) {
|
||||
return new AggregateFunctionTemplate<Decimal32>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
if (which.idx == TypeIndex::Decimal64) {
|
||||
return new AggregateFunctionTemplate<Decimal64>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
if (which.idx == TypeIndex::Decimal128) {
|
||||
return new AggregateFunctionTemplate<Decimal128>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
if (which.idx == TypeIndex::Decimal128I) {
|
||||
return new AggregateFunctionTemplate<Decimal128I>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) \
|
||||
return new AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...);
|
||||
FOR_DECIMAL_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
|
||||
static IAggregateFunction* create_with_decimal_type_null(const DataTypes& argument_types,
|
||||
const Array& params, TArgs&&... args) {
|
||||
WhichDataType which(argument_types[0]);
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) \
|
||||
return new AggregateFunctionNullUnaryInline<AggregateFunctionTemplate<TYPE>, true>( \
|
||||
new AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...), argument_types, \
|
||||
params);
|
||||
FOR_DECIMAL_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -197,18 +157,11 @@ template <template <typename, typename> class AggregateFunctionTemplate, typenam
|
||||
static IAggregateFunction* create_with_decimal_type(const IDataType& argument_type,
|
||||
TArgs&&... args) {
|
||||
WhichDataType which(argument_type);
|
||||
if (which.idx == TypeIndex::Decimal32) {
|
||||
return new AggregateFunctionTemplate<Decimal32, Data>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
if (which.idx == TypeIndex::Decimal64) {
|
||||
return new AggregateFunctionTemplate<Decimal64, Data>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
if (which.idx == TypeIndex::Decimal128) {
|
||||
return new AggregateFunctionTemplate<Decimal128, Data>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
if (which.idx == TypeIndex::Decimal128I) {
|
||||
return new AggregateFunctionTemplate<Decimal128I, Data>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) \
|
||||
return new AggregateFunctionTemplate<TYPE, Data>(std::forward<TArgs>(args)...);
|
||||
FOR_DECIMAL_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -224,12 +177,6 @@ static IAggregateFunction* create_with_two_numeric_types_second(const IDataType&
|
||||
return new AggregateFunctionTemplate<FirstType, TYPE>(std::forward<TArgs>(args)...);
|
||||
FOR_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
if (which.idx == TypeIndex::Enum8) {
|
||||
return new AggregateFunctionTemplate<FirstType, Int8>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
if (which.idx == TypeIndex::Enum16) {
|
||||
return new AggregateFunctionTemplate<FirstType, Int16>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -244,14 +191,6 @@ static IAggregateFunction* create_with_two_numeric_types(const IDataType& first_
|
||||
second_type, std::forward<TArgs>(args)...);
|
||||
FOR_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
if (which.idx == TypeIndex::Enum8) {
|
||||
return create_with_two_numeric_types_second<Int8, AggregateFunctionTemplate>(
|
||||
second_type, std::forward<TArgs>(args)...);
|
||||
}
|
||||
if (which.idx == TypeIndex::Enum16) {
|
||||
return create_with_two_numeric_types_second<Int16, AggregateFunctionTemplate>(
|
||||
second_type, std::forward<TArgs>(args)...);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
@ -158,4 +158,16 @@ DataTypePtr remove_nullable(const DataTypePtr& type) {
|
||||
return type;
|
||||
}
|
||||
|
||||
DataTypes remove_nullable(const DataTypes& types) {
|
||||
DataTypes no_null_types;
|
||||
for (auto& type : types) {
|
||||
if (type->is_nullable()) {
|
||||
no_null_types.push_back(static_cast<const DataTypeNullable&>(*type).get_nested_type());
|
||||
} else {
|
||||
no_null_types.push_back(type);
|
||||
}
|
||||
}
|
||||
return no_null_types;
|
||||
}
|
||||
|
||||
} // namespace doris::vectorized
|
||||
|
||||
@ -93,5 +93,6 @@ private:
|
||||
|
||||
DataTypePtr make_nullable(const DataTypePtr& type);
|
||||
DataTypePtr remove_nullable(const DataTypePtr& type);
|
||||
DataTypes remove_nullable(const DataTypes& types);
|
||||
|
||||
} // namespace doris::vectorized
|
||||
|
||||
Reference in New Issue
Block a user