[Opt](exec) opt aggregate function performance in nullable column

This commit is contained in:
HappenLee
2023-02-16 22:26:12 +08:00
committed by GitHub
parent 4c7f19ab02
commit 24ef60b491
9 changed files with 341 additions and 120 deletions

View File

@ -45,11 +45,23 @@ AggregateFunctionPtr create_aggregate_function_avg(const std::string& name,
AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
if (is_decimal(data_type)) {
res.reset(
create_with_decimal_type<AggregateFuncAvg>(*data_type, *data_type, argument_types));
if (data_type->is_nullable()) {
auto no_null_argument_types = remove_nullable(argument_types);
if (is_decimal(no_null_argument_types[0])) {
res.reset(create_with_decimal_type_null<AggregateFuncAvg>(
no_null_argument_types, parameters, *no_null_argument_types[0],
no_null_argument_types));
} else {
res.reset(create_with_numeric_type_null<AggregateFuncAvg>(
no_null_argument_types, parameters, no_null_argument_types));
}
} else {
res.reset(create_with_numeric_type<AggregateFuncAvg>(*data_type, argument_types));
if (is_decimal(data_type)) {
res.reset(create_with_decimal_type<AggregateFuncAvg>(*data_type, *data_type,
argument_types));
} else {
res.reset(create_with_numeric_type<AggregateFuncAvg>(*data_type, argument_types));
}
}
if (!res) {
@ -61,5 +73,6 @@ AggregateFunctionPtr create_aggregate_function_avg(const std::string& name,
void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) {
factory.register_function("avg", create_aggregate_function_avg);
factory.register_function("avg", create_aggregate_function_avg, true);
}
} // namespace doris::vectorized

View File

@ -121,7 +121,8 @@ public:
DataTypePtr get_serialized_type() const override { return std::make_shared<DataTypeUInt64>(); }
};
/// Simply count number of not-NULL values.
// TODO: Maybe AggregateFunctionCountNotNullUnary should be a subclass of AggregateFunctionCount
// Simply count number of not-NULL values.
class AggregateFunctionCountNotNullUnary final
: public IAggregateFunctionDataHelper<AggregateFunctionCountData,
AggregateFunctionCountNotNullUnary> {

View File

@ -25,7 +25,6 @@
#include "vec/aggregate_functions/helpers.h"
namespace doris::vectorized {
/// min, max, any
template <template <typename, bool> class AggregateFunctionTemplate, template <typename> class Data>
static IAggregateFunction* create_aggregate_function_single_value(const String& name,

View File

@ -85,7 +85,6 @@ public:
};
void register_aggregate_function_combinator_null(AggregateFunctionSimpleFactory& factory) {
// factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorNull>());
AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types,
const Array& params, const bool result_is_nullable) {
auto function_combinator = std::make_shared<AggregateFunctionCombinatorNull>();

View File

@ -40,6 +40,7 @@ namespace doris::vectorized {
/// If all rows had NULL, the behaviour is determined by "result_is_nullable" template parameter.
/// true - return NULL; false - return value from empty aggregation state of nested function.
// TODO: only keep class xxxInline after we support all aggregate function
template <bool result_is_nullable, typename Derived>
class AggregateFunctionNullBase : public IAggregateFunctionHelper<Derived> {
protected:
@ -409,4 +410,248 @@ private:
is_nullable; /// Plain array is better than std::vector due to one indirection less.
};
/// Base class for Nullable-argument wrappers that hold the nested aggregate function by its
/// concrete type `NestFunction`.
///
/// State layout: when `result_is_nullable` is true, the first `prefix_size` bytes of the state
/// are reserved for a "saw at least one non-NULL value" flag (stored in byte 0), and the nested
/// function's state follows. `prefix_size` equals the nested state's alignment so that the
/// nested state stays properly aligned. When `result_is_nullable` is false there is no prefix
/// and the flag is treated as always set.
template <typename NestFunction, bool result_is_nullable, typename Derived>
class AggregateFunctionNullBaseInline : public IAggregateFunctionHelper<Derived> {
protected:
    std::unique_ptr<NestFunction> nested_function;
    size_t prefix_size;

    /// Address of the nested function's state inside the combined state.
    AggregateDataPtr nested_place(AggregateDataPtr __restrict place) const noexcept {
        return place + prefix_size;
    }

    ConstAggregateDataPtr nested_place(ConstAggregateDataPtr __restrict place) const noexcept {
        return place + prefix_size;
    }

    /// Clears the flag; a no-op when the result is not nullable.
    static void init_flag(AggregateDataPtr __restrict place) noexcept {
        if constexpr (result_is_nullable) {
            place[0] = false;
        }
    }

    /// Raises the flag; a no-op when the result is not nullable.
    static void set_flag(AggregateDataPtr __restrict place) noexcept {
        if constexpr (result_is_nullable) {
            place[0] = true;
        }
    }

    /// Reads the flag; always true when the result is not nullable.
    static bool get_flag(ConstAggregateDataPtr __restrict place) noexcept {
        if constexpr (result_is_nullable) {
            return place[0];
        } else {
            return true;
        }
    }

public:
    /// Takes ownership of `nested_function_`, which must actually be a `NestFunction`.
    AggregateFunctionNullBaseInline(IAggregateFunction* nested_function_,
                                    const DataTypes& arguments, const Array& params)
            : IAggregateFunctionHelper<Derived>(arguments, params),
              nested_function {assert_cast<NestFunction*>(nested_function_)} {
        if constexpr (result_is_nullable) {
            prefix_size = nested_function->align_of_data();
        } else {
            prefix_size = 0;
        }
    }

    /// The wrapper is transparent: it reports the nested function's name.
    String get_name() const override { return nested_function->get_name(); }

    DataTypePtr get_return_type() const override {
        if constexpr (result_is_nullable) {
            return make_nullable(nested_function->get_return_type());
        } else {
            return nested_function->get_return_type();
        }
    }

    void create(AggregateDataPtr __restrict place) const override {
        init_flag(place);
        nested_function->create(nested_place(place));
    }

    void destroy(AggregateDataPtr __restrict place) const noexcept override {
        nested_function->destroy(nested_place(place));
    }

    void reset(AggregateDataPtr place) const override {
        init_flag(place);
        nested_function->reset(nested_place(place));
    }

    bool has_trivial_destructor() const override {
        return nested_function->has_trivial_destructor();
    }

    size_t size_of_data() const override { return prefix_size + nested_function->size_of_data(); }

    size_t align_of_data() const override { return nested_function->align_of_data(); }

    /// Propagates the flag from `rhs` and always merges the nested states.
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena* arena) const override {
        if constexpr (result_is_nullable) {
            if (get_flag(rhs)) {
                set_flag(place);
            }
        }
        nested_function->merge(nested_place(place), nested_place(rhs), arena);
    }

    /// Writes the flag (nullable results only), then the nested state if the flag is set.
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        bool has_value = get_flag(place);
        if constexpr (result_is_nullable) {
            write_binary(has_value, buf);
        }
        if (!has_value) {
            return;
        }
        nested_function->serialize(nested_place(place), buf);
    }

    /// Mirror of serialize(): the nested state is read only when the stored flag is set.
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena* arena) const override {
        bool has_value = true;
        if constexpr (result_is_nullable) {
            read_binary(has_value, buf);
        }
        if (!has_value) {
            return;
        }
        set_flag(place);
        nested_function->deserialize(nested_place(place), buf, arena);
    }

    void deserialize_and_merge(AggregateDataPtr __restrict place, BufferReadable& buf,
                               Arena* arena) const override {
        bool has_value = true;
        if constexpr (result_is_nullable) {
            read_binary(has_value, buf);
        }
        if (!has_value) {
            return;
        }
        set_flag(place);
        nested_function->deserialize_and_merge(nested_place(place), buf, arena);
    }

    /// Treats `column` as a ColumnString of serialized states and merges each row into `place`.
    void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column,
                                           Arena* arena) const override {
        const size_t row_count = column.size();
        for (size_t row = 0; row != row_count; ++row) {
            VectorBufferReader row_reader(
                    (assert_cast<const ColumnString&>(column)).get_data_at(row));
            deserialize_and_merge(place, row_reader, arena);
        }
    }

    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        if constexpr (!result_is_nullable) {
            nested_function->insert_result_into(nested_place(place), to);
        } else {
            ColumnNullable& nullable_to = assert_cast<ColumnNullable&>(to);
            if (!get_flag(place)) {
                // No non-NULL value was accumulated.
                nullable_to.insert_default();
                return;
            }
            nested_function->insert_result_into(nested_place(place),
                                                nullable_to.get_nested_column());
            nullable_to.get_null_map_data().push_back(0);
        }
    }

    bool allocates_memory_in_arena() const override {
        return nested_function->allocates_memory_in_arena();
    }

    bool is_state() const override { return nested_function->is_state(); }
};
/** Nullable wrapper specialised for a single argument (the single-argument case is much
  * cheaper than a variadic one).
  *
  * columns[0] must be a ColumnNullable. Rows whose null-map entry is set are skipped; every
  * other row raises the "has value" flag and is forwarded to the nested function together
  * with the nested (non-nullable) column.
  */
template <typename NestFunction, bool result_is_nullable>
class AggregateFunctionNullUnaryInline final
        : public AggregateFunctionNullBaseInline<
                  NestFunction, result_is_nullable,
                  AggregateFunctionNullUnaryInline<NestFunction, result_is_nullable>> {
    using Base = AggregateFunctionNullBaseInline<
            NestFunction, result_is_nullable,
            AggregateFunctionNullUnaryInline<NestFunction, result_is_nullable>>;

public:
    AggregateFunctionNullUnaryInline(IAggregateFunction* nested_function_,
                                     const DataTypes& arguments, const Array& params)
            : Base(nested_function_, arguments, params) {}

    /// Adds row `row_num` to `place` unless that row is NULL.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
             Arena* arena) const override {
        const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
        if (!column->is_null_at(row_num)) {
            this->set_flag(place);
            const IColumn* nested_column = &column->get_nested_column();
            this->nested_function->add(this->nested_place(place), &nested_column, row_num, arena);
        }
    }

    /// Same as add() but without the NULL check: the caller guarantees the row is not NULL.
    void add_not_nullable(AggregateDataPtr __restrict place, const IColumn** columns,
                          size_t row_num, Arena* arena) const {
        const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
        this->set_flag(place);
        const IColumn* nested_column = &column->get_nested_column();
        this->nested_function->add(this->nested_place(place), &nested_column, row_num, arena);
    }

    /// Adds row i to the state at `places[i] + place_offset` for every non-NULL row.
    /// `agg_many` is not consulted by this implementation.
    void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
                   const IColumn** columns, Arena* arena, bool agg_many) const override {
        const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
        // The overhead introduced is negligible here, just an extra memory read from NullMap
        const auto* __restrict null_map_data = column->get_null_map_data().data();
        const IColumn* nested_column = &column->get_nested_column();
        // size_t index: avoids the signed/unsigned comparison against batch_size.
        for (size_t i = 0; i < batch_size; ++i) {
            if (!null_map_data[i]) {
                AggregateDataPtr __restrict place = places[i] + place_offset;
                this->set_flag(place);
                this->nested_function->add(this->nested_place(place), &nested_column, i, arena);
            }
        }
    }

    /// Adds the first `batch_size` rows to a single state.
    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
                                Arena* arena) const override {
        const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
        const IColumn* nested_column = &column->get_nested_column();
        if (column->has_null()) {
            // Slow path: check each row, then go straight to the nested function instead of
            // through add(), which would repeat the cast and the NULL check.
            for (size_t i = 0; i < batch_size; ++i) {
                if (!column->is_null_at(i)) {
                    this->set_flag(place);
                    this->nested_function->add(this->nested_place(place), &nested_column, i,
                                               arena);
                }
            }
        } else {
            // Fast path: no NULLs, hand the whole batch to the nested function. Only mark
            // the state as "has value" when at least one row is actually added.
            if (batch_size > 0) {
                this->set_flag(place);
            }
            this->nested_function->add_batch_single_place(batch_size, this->nested_place(place),
                                                          &nested_column, arena);
        }
    }

    /// Adds rows in the inclusive range [batch_begin, batch_end] to a single state.
    /// `has_null` is supplied by the caller and selects the per-row or the bulk path.
    void add_batch_range(size_t batch_begin, size_t batch_end, AggregateDataPtr place,
                         const IColumn** columns, Arena* arena, bool has_null) override {
        const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
        const IColumn* nested_column = &column->get_nested_column();
        if (has_null) {
            for (size_t i = batch_begin; i <= batch_end; ++i) {
                if (!column->is_null_at(i)) {
                    this->set_flag(place);
                    this->nested_function->add(this->nested_place(place), &nested_column, i,
                                               arena);
                }
            }
        } else {
            // An inverted range contains no rows, so it must not mark the state as non-NULL.
            if (batch_begin <= batch_end) {
                this->set_flag(place);
            }
            this->nested_function->add_batch_range(batch_begin, batch_end,
                                                   this->nested_place(place), &nested_column, arena,
                                                   false);
        }
    }
};
} // namespace doris::vectorized

View File

@ -25,6 +25,7 @@
#include "common/logging.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/helpers.h"
#include "vec/data_types/data_type_nullable.h"
namespace doris::vectorized {
@ -45,15 +46,24 @@ AggregateFunctionPtr create_aggregate_function_sum(const std::string& name,
const DataTypes& argument_types,
const Array& parameters,
const bool result_is_nullable) {
// assert_no_parameters(name, parameters);
// assert_unary(name, argument_types);
AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
if (is_decimal(data_type)) {
res.reset(create_with_decimal_type<Function>(*data_type, *data_type, argument_types));
if (data_type->is_nullable()) {
auto no_null_argument_types = remove_nullable(argument_types);
if (is_decimal(no_null_argument_types[0])) {
res.reset(create_with_decimal_type_null<Function>(no_null_argument_types, parameters,
*no_null_argument_types[0],
no_null_argument_types));
} else {
res.reset(create_with_numeric_type_null<Function>(no_null_argument_types, parameters,
no_null_argument_types));
}
} else {
res.reset(create_with_numeric_type<Function>(*data_type, argument_types));
if (is_decimal(data_type)) {
res.reset(create_with_decimal_type<Function>(*data_type, *data_type, argument_types));
} else {
res.reset(create_with_numeric_type<Function>(*data_type, argument_types));
}
}
if (!res) {
@ -84,6 +94,8 @@ AggregateFunctionPtr create_aggregate_function_sum_reader(const std::string& nam
void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory) {
factory.register_function("sum", create_aggregate_function_sum<AggregateFunctionSumSimple>);
factory.register_function("sum", create_aggregate_function_sum<AggregateFunctionSumSimple>,
true);
}
} // namespace doris::vectorized

View File

@ -21,8 +21,10 @@
#pragma once
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/aggregate_function_null.h"
#include "vec/data_types/data_type.h"
// TODO: Should we support decimal in numeric types?
#define FOR_NUMERIC_TYPES(M) \
M(UInt8) \
M(UInt16) \
@ -36,6 +38,12 @@
M(Float32) \
M(Float64)
// X-macro listing the decimal types: expands M(TYPE) once for each of Decimal32, Decimal64,
// Decimal128 and Decimal128I. Users in this file pass a macro that uses TYPE both as a
// TypeIndex enumerator and as a template argument.
#define FOR_DECIMAL_TYPES(M) \
M(Decimal32) \
M(Decimal64) \
M(Decimal128) \
M(Decimal128I)
namespace doris::vectorized {
/** Create an aggregate function with a numeric type in the template parameter, depending on the type of the argument.
@ -49,12 +57,20 @@ static IAggregateFunction* create_with_numeric_type(const IDataType& argument_ty
return new AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Enum8) {
return new AggregateFunctionTemplate<Int8>(std::forward<TArgs>(args)...);
}
if (which.idx == TypeIndex::Enum16) {
return new AggregateFunctionTemplate<Int16>(std::forward<TArgs>(args)...);
}
return nullptr;
}
/// Builds `AggregateFunctionTemplate<T>` for the numeric type T identified by
/// `argument_types[0]` and wraps it in `AggregateFunctionNullUnaryInline<..., true>`.
/// `args` are forwarded to the nested function's constructor; `argument_types` and `params`
/// go to the wrapper. Returns nullptr when the type is not one of FOR_NUMERIC_TYPES.
template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction* create_with_numeric_type_null(const DataTypes& argument_types,
                                                         const Array& params, TArgs&&... args) {
    WhichDataType type_info(argument_types[0]);
#define RETURN_NULLABLE_WRAPPER_IF(TYPE)                                              \
    if (type_info.idx == TypeIndex::TYPE) {                                           \
        using Nested = AggregateFunctionTemplate<TYPE>;                               \
        return new AggregateFunctionNullUnaryInline<Nested, true>(                    \
                new Nested(std::forward<TArgs>(args)...), argument_types, params);    \
    }
    FOR_NUMERIC_TYPES(RETURN_NULLABLE_WRAPPER_IF)
#undef RETURN_NULLABLE_WRAPPER_IF
    return nullptr;
}
@ -68,12 +84,6 @@ static IAggregateFunction* create_with_numeric_type(const IDataType& argument_ty
return new AggregateFunctionTemplate<TYPE, bool_param>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Enum8) {
return new AggregateFunctionTemplate<Int8, bool_param>(std::forward<TArgs>(args)...);
}
if (which.idx == TypeIndex::Enum16) {
return new AggregateFunctionTemplate<Int16, bool_param>(std::forward<TArgs>(args)...);
}
return nullptr;
}
@ -87,12 +97,6 @@ static IAggregateFunction* create_with_numeric_type(const IDataType& argument_ty
return new AggregateFunctionTemplate<TYPE, Data>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Enum8) {
return new AggregateFunctionTemplate<Int8, Data>(std::forward<TArgs>(args)...);
}
if (which.idx == TypeIndex::Enum16) {
return new AggregateFunctionTemplate<Int16, Data>(std::forward<TArgs>(args)...);
}
return nullptr;
}
@ -106,12 +110,6 @@ static IAggregateFunction* create_with_numeric_type(const IDataType& argument_ty
return new AggregateFunctionTemplate<TYPE, Data<TYPE>>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Enum8) {
return new AggregateFunctionTemplate<Int8, Data<Int8>>(std::forward<TArgs>(args)...);
}
if (which.idx == TypeIndex::Enum16) {
return new AggregateFunctionTemplate<Int16, Data<Int16>>(std::forward<TArgs>(args)...);
}
return nullptr;
}
@ -125,51 +123,6 @@ static IAggregateFunction* create_with_numeric_type(const IDataType& argument_ty
return new AggregateFunctionTemplate<Data<TYPE>>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
// if (which.idx == TypeIndex::Enum8) return new AggregateFunctionTemplate<Data<Int8>>(std::forward<TArgs>(args)...);
// if (which.idx == TypeIndex::Enum16) return new AggregateFunctionTemplate<Data<Int16>>(std::forward<TArgs>(args)...);
return nullptr;
}
template <template <typename, typename> class AggregateFunctionTemplate,
template <typename> class Data, typename... TArgs>
static IAggregateFunction* create_with_unsigned_integer_type(const IDataType& argument_type,
TArgs&&... args) {
WhichDataType which(argument_type);
if (which.idx == TypeIndex::UInt8) {
return new AggregateFunctionTemplate<UInt8, Data<UInt8>>(std::forward<TArgs>(args)...);
}
if (which.idx == TypeIndex::UInt16) {
return new AggregateFunctionTemplate<UInt16, Data<UInt16>>(std::forward<TArgs>(args)...);
}
if (which.idx == TypeIndex::UInt32) {
return new AggregateFunctionTemplate<UInt32, Data<UInt32>>(std::forward<TArgs>(args)...);
}
if (which.idx == TypeIndex::UInt64) {
return new AggregateFunctionTemplate<UInt64, Data<UInt64>>(std::forward<TArgs>(args)...);
}
return nullptr;
}
template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction* create_with_numeric_based_type(const IDataType& argument_type,
TArgs&&... args) {
IAggregateFunction* f = create_with_numeric_type<AggregateFunctionTemplate>(
argument_type, std::forward<TArgs>(args)...);
if (f) {
return f;
}
/// expects that DataTypeDate based on UInt16, DataTypeDateTime based on UInt32 and UUID based on UInt128
WhichDataType which(argument_type);
if (which.idx == TypeIndex::Date) {
return new AggregateFunctionTemplate<UInt16>(std::forward<TArgs>(args)...);
}
if (which.idx == TypeIndex::DateTime) {
return new AggregateFunctionTemplate<UInt32>(std::forward<TArgs>(args)...);
}
if (which.idx == TypeIndex::UUID) {
return new AggregateFunctionTemplate<UInt128>(std::forward<TArgs>(args)...);
}
return nullptr;
}
@ -177,18 +130,25 @@ template <template <typename> class AggregateFunctionTemplate, typename... TArgs
static IAggregateFunction* create_with_decimal_type(const IDataType& argument_type,
TArgs&&... args) {
WhichDataType which(argument_type);
if (which.idx == TypeIndex::Decimal32) {
return new AggregateFunctionTemplate<Decimal32>(std::forward<TArgs>(args)...);
}
if (which.idx == TypeIndex::Decimal64) {
return new AggregateFunctionTemplate<Decimal64>(std::forward<TArgs>(args)...);
}
if (which.idx == TypeIndex::Decimal128) {
return new AggregateFunctionTemplate<Decimal128>(std::forward<TArgs>(args)...);
}
if (which.idx == TypeIndex::Decimal128I) {
return new AggregateFunctionTemplate<Decimal128I>(std::forward<TArgs>(args)...);
}
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...);
FOR_DECIMAL_TYPES(DISPATCH)
#undef DISPATCH
return nullptr;
}
/// Builds `AggregateFunctionTemplate<T>` for the decimal type T identified by
/// `argument_types[0]` and wraps it in `AggregateFunctionNullUnaryInline<..., true>`.
/// `args` are forwarded to the nested function's constructor; `argument_types` and `params`
/// go to the wrapper. Returns nullptr when the type is not one of FOR_DECIMAL_TYPES.
template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction* create_with_decimal_type_null(const DataTypes& argument_types,
                                                         const Array& params, TArgs&&... args) {
    WhichDataType type_info(argument_types[0]);
#define RETURN_NULLABLE_WRAPPER_IF(TYPE)                                              \
    if (type_info.idx == TypeIndex::TYPE) {                                           \
        using Nested = AggregateFunctionTemplate<TYPE>;                               \
        return new AggregateFunctionNullUnaryInline<Nested, true>(                    \
                new Nested(std::forward<TArgs>(args)...), argument_types, params);    \
    }
    FOR_DECIMAL_TYPES(RETURN_NULLABLE_WRAPPER_IF)
#undef RETURN_NULLABLE_WRAPPER_IF
    return nullptr;
}
@ -197,18 +157,11 @@ template <template <typename, typename> class AggregateFunctionTemplate, typenam
static IAggregateFunction* create_with_decimal_type(const IDataType& argument_type,
TArgs&&... args) {
WhichDataType which(argument_type);
if (which.idx == TypeIndex::Decimal32) {
return new AggregateFunctionTemplate<Decimal32, Data>(std::forward<TArgs>(args)...);
}
if (which.idx == TypeIndex::Decimal64) {
return new AggregateFunctionTemplate<Decimal64, Data>(std::forward<TArgs>(args)...);
}
if (which.idx == TypeIndex::Decimal128) {
return new AggregateFunctionTemplate<Decimal128, Data>(std::forward<TArgs>(args)...);
}
if (which.idx == TypeIndex::Decimal128I) {
return new AggregateFunctionTemplate<Decimal128I, Data>(std::forward<TArgs>(args)...);
}
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionTemplate<TYPE, Data>(std::forward<TArgs>(args)...);
FOR_DECIMAL_TYPES(DISPATCH)
#undef DISPATCH
return nullptr;
}
@ -224,12 +177,6 @@ static IAggregateFunction* create_with_two_numeric_types_second(const IDataType&
return new AggregateFunctionTemplate<FirstType, TYPE>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Enum8) {
return new AggregateFunctionTemplate<FirstType, Int8>(std::forward<TArgs>(args)...);
}
if (which.idx == TypeIndex::Enum16) {
return new AggregateFunctionTemplate<FirstType, Int16>(std::forward<TArgs>(args)...);
}
return nullptr;
}
@ -244,14 +191,6 @@ static IAggregateFunction* create_with_two_numeric_types(const IDataType& first_
second_type, std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Enum8) {
return create_with_two_numeric_types_second<Int8, AggregateFunctionTemplate>(
second_type, std::forward<TArgs>(args)...);
}
if (which.idx == TypeIndex::Enum16) {
return create_with_two_numeric_types_second<Int16, AggregateFunctionTemplate>(
second_type, std::forward<TArgs>(args)...);
}
return nullptr;
}

View File

@ -158,4 +158,16 @@ DataTypePtr remove_nullable(const DataTypePtr& type) {
return type;
}
/// Returns a copy of `types` in which every Nullable type is replaced by its nested type;
/// non-nullable entries are copied through unchanged. Order and length are preserved.
DataTypes remove_nullable(const DataTypes& types) {
    DataTypes no_null_types;
    // The output has exactly one entry per input, so allocate once up front.
    no_null_types.reserve(types.size());
    for (const auto& type : types) {
        if (type->is_nullable()) {
            no_null_types.push_back(static_cast<const DataTypeNullable&>(*type).get_nested_type());
        } else {
            no_null_types.push_back(type);
        }
    }
    return no_null_types;
}
} // namespace doris::vectorized

View File

@ -93,5 +93,6 @@ private:
DataTypePtr make_nullable(const DataTypePtr& type);
DataTypePtr remove_nullable(const DataTypePtr& type);
DataTypes remove_nullable(const DataTypes& types);
} // namespace doris::vectorized