[Enchancement](function) refact and optimize some function register (#16955)
refact and optimize some function register
This commit is contained in:
@ -41,25 +41,8 @@ AggregateFunctionPtr create_aggregate_function_avg(const std::string& name,
|
||||
const bool result_is_nullable) {
|
||||
assert_unary(name, argument_types);
|
||||
|
||||
AggregateFunctionPtr res;
|
||||
DataTypePtr data_type = argument_types[0];
|
||||
if (data_type->is_nullable()) {
|
||||
auto no_null_argument_types = remove_nullable(argument_types);
|
||||
if (is_decimal(no_null_argument_types[0])) {
|
||||
res.reset(create_with_decimal_type_null<AggregateFuncAvg>(
|
||||
no_null_argument_types, *no_null_argument_types[0], no_null_argument_types));
|
||||
} else {
|
||||
res.reset(create_with_numeric_type_null<AggregateFuncAvg>(no_null_argument_types,
|
||||
no_null_argument_types));
|
||||
}
|
||||
} else {
|
||||
if (is_decimal(data_type)) {
|
||||
res.reset(create_with_decimal_type<AggregateFuncAvg>(*data_type, *data_type,
|
||||
argument_types));
|
||||
} else {
|
||||
res.reset(create_with_numeric_type<AggregateFuncAvg>(*data_type, argument_types));
|
||||
}
|
||||
}
|
||||
AggregateFunctionPtr res(
|
||||
creator_with_type::create<AggregateFuncAvg>(result_is_nullable, argument_types));
|
||||
|
||||
if (!res) {
|
||||
LOG(WARNING) << fmt::format("Illegal type {} of argument for aggregate function {}",
|
||||
|
||||
@ -95,12 +95,7 @@ public:
|
||||
/// ctor for native types
|
||||
AggregateFunctionAvg(const DataTypes& argument_types_)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>(argument_types_),
|
||||
scale(0) {}
|
||||
|
||||
/// ctor for Decimals
|
||||
AggregateFunctionAvg(const IDataType& data_type, const DataTypes& argument_types_)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>(argument_types_),
|
||||
scale(get_decimal_scale(data_type)) {}
|
||||
scale(get_decimal_scale(*argument_types_[0])) {}
|
||||
|
||||
String get_name() const override { return "avg"; }
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@
|
||||
#include "vec/aggregate_functions/aggregate_function_bit.h"
|
||||
|
||||
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
|
||||
#include "vec/aggregate_functions/helpers.h"
|
||||
|
||||
namespace doris::vectorized {
|
||||
|
||||
@ -34,26 +35,10 @@ AggregateFunctionPtr createAggregateFunctionBitwise(const std::string& name,
|
||||
" is illegal, because it cannot be used in bitwise operations");
|
||||
}
|
||||
|
||||
auto type = argument_types[0].get();
|
||||
if (type->is_nullable()) {
|
||||
type = assert_cast<const DataTypeNullable*>(type)->get_nested_type().get();
|
||||
}
|
||||
|
||||
WhichDataType which(*type);
|
||||
if (which.is_int8()) {
|
||||
return AggregateFunctionPtr(new AggregateFunctionBitwise<Int8, Data<Int8>>(argument_types));
|
||||
} else if (which.is_int16()) {
|
||||
return AggregateFunctionPtr(
|
||||
new AggregateFunctionBitwise<Int16, Data<Int16>>(argument_types));
|
||||
} else if (which.is_int32()) {
|
||||
return AggregateFunctionPtr(
|
||||
new AggregateFunctionBitwise<Int32, Data<Int32>>(argument_types));
|
||||
} else if (which.is_int64()) {
|
||||
return AggregateFunctionPtr(
|
||||
new AggregateFunctionBitwise<Int64, Data<Int64>>(argument_types));
|
||||
} else if (which.is_int128()) {
|
||||
return AggregateFunctionPtr(
|
||||
new AggregateFunctionBitwise<Int128, Data<Int128>>(argument_types));
|
||||
AggregateFunctionPtr res(creator_with_integer_type::create<AggregateFunctionBitwise, Data>(
|
||||
result_is_nullable, argument_types));
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
LOG(WARNING) << fmt::format("Illegal type " + argument_types[0]->get_name() +
|
||||
@ -68,6 +53,15 @@ void register_aggregate_function_bit(AggregateFunctionSimpleFactory& factory) {
|
||||
createAggregateFunctionBitwise<AggregateFunctionGroupBitAndData>);
|
||||
factory.register_function("group_bit_xor",
|
||||
createAggregateFunctionBitwise<AggregateFunctionGroupBitXorData>);
|
||||
|
||||
factory.register_function(
|
||||
"group_bit_or", createAggregateFunctionBitwise<AggregateFunctionGroupBitOrData>, true);
|
||||
factory.register_function("group_bit_and",
|
||||
createAggregateFunctionBitwise<AggregateFunctionGroupBitAndData>,
|
||||
true);
|
||||
factory.register_function("group_bit_xor",
|
||||
createAggregateFunctionBitwise<AggregateFunctionGroupBitXorData>,
|
||||
true);
|
||||
}
|
||||
|
||||
} // namespace doris::vectorized
|
||||
@ -49,30 +49,30 @@ public:
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AggregateFunctionPtr res;
|
||||
if (arguments.size() == 1) {
|
||||
res.reset(create_with_numeric_type<AggregateFunctionDistinct,
|
||||
AggregateFunctionDistinctSingleNumericData>(
|
||||
*arguments[0], nested_function, arguments));
|
||||
|
||||
AggregateFunctionPtr res(
|
||||
creator_with_numeric_type::create<AggregateFunctionDistinct,
|
||||
AggregateFunctionDistinctSingleNumericData>(
|
||||
result_is_nullable, arguments, nested_function));
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
if (arguments[0]->is_value_unambiguously_represented_in_contiguous_memory_region()) {
|
||||
return std::make_shared<AggregateFunctionDistinct<
|
||||
AggregateFunctionDistinctSingleGenericData<true>>>(nested_function,
|
||||
arguments);
|
||||
res.reset(creator_without_type::create<AggregateFunctionDistinct<
|
||||
AggregateFunctionDistinctSingleGenericData<true>>>(
|
||||
result_is_nullable, arguments, nested_function));
|
||||
} else {
|
||||
return std::make_shared<AggregateFunctionDistinct<
|
||||
AggregateFunctionDistinctSingleGenericData<false>>>(nested_function,
|
||||
arguments);
|
||||
res.reset(creator_without_type::create<AggregateFunctionDistinct<
|
||||
AggregateFunctionDistinctSingleGenericData<false>>>(
|
||||
result_is_nullable, arguments, nested_function));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
return std::make_shared<
|
||||
AggregateFunctionDistinct<AggregateFunctionDistinctMultipleGenericData>>(
|
||||
nested_function, arguments);
|
||||
return AggregateFunctionPtr(
|
||||
creator_without_type::create<
|
||||
AggregateFunctionDistinct<AggregateFunctionDistinctMultipleGenericData>>(
|
||||
result_is_nullable, arguments, nested_function));
|
||||
}
|
||||
};
|
||||
|
||||
@ -93,5 +93,6 @@ void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFact
|
||||
result_is_nullable);
|
||||
};
|
||||
factory.register_distinct_function_combinator(creator, DISTINCT_FUNCTION_PREFIX);
|
||||
factory.register_distinct_function_combinator(creator, DISTINCT_FUNCTION_PREFIX, true);
|
||||
}
|
||||
} // namespace doris::vectorized
|
||||
|
||||
@ -26,50 +26,45 @@
|
||||
|
||||
namespace doris::vectorized {
|
||||
/// min, max, any
|
||||
template <template <typename, bool> class AggregateFunctionTemplate, template <typename> class Data>
|
||||
template <template <typename> class AggregateFunctionTemplate, template <typename> class Data>
|
||||
static IAggregateFunction* create_aggregate_function_single_value(const String& name,
|
||||
const DataTypes& argument_types) {
|
||||
const DataTypes& argument_types,
|
||||
const bool result_is_nullable) {
|
||||
assert_unary(name, argument_types);
|
||||
|
||||
const DataTypePtr& argument_type = argument_types[0];
|
||||
|
||||
IAggregateFunction* res(creator_with_numeric_type::create<AggregateFunctionTemplate, Data,
|
||||
SingleValueDataFixed>(
|
||||
result_is_nullable, argument_types));
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
res = creator_with_decimal_type::create<AggregateFunctionTemplate, Data,
|
||||
SingleValueDataDecimal>(result_is_nullable,
|
||||
argument_types);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);
|
||||
WhichDataType which(argument_type);
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) \
|
||||
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<TYPE>>, false>( \
|
||||
argument_type);
|
||||
FOR_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
|
||||
if (which.idx == TypeIndex::String) {
|
||||
return new AggregateFunctionTemplate<Data<SingleValueDataString>, false>(argument_type);
|
||||
return creator_without_type::create<AggregateFunctionTemplate<Data<SingleValueDataString>>>(
|
||||
result_is_nullable, argument_types);
|
||||
}
|
||||
if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::Date) {
|
||||
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Int64>>, false>(
|
||||
argument_type);
|
||||
return creator_without_type::create<
|
||||
AggregateFunctionTemplate<Data<SingleValueDataFixed<Int64>>>>(result_is_nullable,
|
||||
argument_types);
|
||||
}
|
||||
if (which.idx == TypeIndex::DateV2) {
|
||||
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<UInt32>>, false>(
|
||||
argument_type);
|
||||
return creator_without_type::create<
|
||||
AggregateFunctionTemplate<Data<SingleValueDataFixed<UInt32>>>>(result_is_nullable,
|
||||
argument_types);
|
||||
}
|
||||
if (which.idx == TypeIndex::DateTimeV2) {
|
||||
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<UInt64>>, false>(
|
||||
argument_type);
|
||||
}
|
||||
if (which.idx == TypeIndex::Decimal32) {
|
||||
return new AggregateFunctionTemplate<Data<SingleValueDataDecimal<Decimal32>>, false>(
|
||||
argument_type);
|
||||
}
|
||||
if (which.idx == TypeIndex::Decimal64) {
|
||||
return new AggregateFunctionTemplate<Data<SingleValueDataDecimal<Decimal64>>, false>(
|
||||
argument_type);
|
||||
}
|
||||
if (which.idx == TypeIndex::Decimal128) {
|
||||
return new AggregateFunctionTemplate<Data<SingleValueDataDecimal<Decimal128>>, false>(
|
||||
argument_type);
|
||||
}
|
||||
if (which.idx == TypeIndex::Decimal128I) {
|
||||
return new AggregateFunctionTemplate<Data<SingleValueDataDecimal<Decimal128I>>, false>(
|
||||
argument_type);
|
||||
return creator_without_type::create<
|
||||
AggregateFunctionTemplate<Data<SingleValueDataFixed<UInt64>>>>(result_is_nullable,
|
||||
argument_types);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
@ -79,7 +74,8 @@ AggregateFunctionPtr create_aggregate_function_max(const std::string& name,
|
||||
const bool result_is_nullable) {
|
||||
return AggregateFunctionPtr(
|
||||
create_aggregate_function_single_value<AggregateFunctionsSingleValue,
|
||||
AggregateFunctionMaxData>(name, argument_types));
|
||||
AggregateFunctionMaxData>(name, argument_types,
|
||||
result_is_nullable));
|
||||
}
|
||||
|
||||
AggregateFunctionPtr create_aggregate_function_min(const std::string& name,
|
||||
@ -87,7 +83,8 @@ AggregateFunctionPtr create_aggregate_function_min(const std::string& name,
|
||||
const bool result_is_nullable) {
|
||||
return AggregateFunctionPtr(
|
||||
create_aggregate_function_single_value<AggregateFunctionsSingleValue,
|
||||
AggregateFunctionMinData>(name, argument_types));
|
||||
AggregateFunctionMinData>(name, argument_types,
|
||||
result_is_nullable));
|
||||
}
|
||||
|
||||
AggregateFunctionPtr create_aggregate_function_any(const std::string& name,
|
||||
@ -95,7 +92,8 @@ AggregateFunctionPtr create_aggregate_function_any(const std::string& name,
|
||||
const bool result_is_nullable) {
|
||||
return AggregateFunctionPtr(
|
||||
create_aggregate_function_single_value<AggregateFunctionsSingleValue,
|
||||
AggregateFunctionAnyData>(name, argument_types));
|
||||
AggregateFunctionAnyData>(name, argument_types,
|
||||
result_is_nullable));
|
||||
}
|
||||
|
||||
void register_aggregate_function_minmax(AggregateFunctionSimpleFactory& factory) {
|
||||
@ -103,6 +101,10 @@ void register_aggregate_function_minmax(AggregateFunctionSimpleFactory& factory)
|
||||
factory.register_function("min", create_aggregate_function_min);
|
||||
factory.register_function("any", create_aggregate_function_any);
|
||||
|
||||
factory.register_function("max", create_aggregate_function_max, true);
|
||||
factory.register_function("min", create_aggregate_function_min, true);
|
||||
factory.register_function("any", create_aggregate_function_any, true);
|
||||
|
||||
factory.register_alias("any", "any_value");
|
||||
}
|
||||
|
||||
|
||||
@ -485,19 +485,16 @@ struct AggregateFunctionAnyData : Data {
|
||||
static const char* name() { return "any"; }
|
||||
};
|
||||
|
||||
template <typename Data, bool AllocatesMemoryInArena>
|
||||
template <typename Data>
|
||||
class AggregateFunctionsSingleValue final
|
||||
: public IAggregateFunctionDataHelper<
|
||||
Data, AggregateFunctionsSingleValue<Data, AllocatesMemoryInArena>> {
|
||||
: public IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data>> {
|
||||
private:
|
||||
DataTypePtr& type;
|
||||
using Base = IAggregateFunctionDataHelper<
|
||||
Data, AggregateFunctionsSingleValue<Data, AllocatesMemoryInArena>>;
|
||||
using Base = IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data>>;
|
||||
|
||||
public:
|
||||
AggregateFunctionsSingleValue(const DataTypePtr& type_)
|
||||
: IAggregateFunctionDataHelper<
|
||||
Data, AggregateFunctionsSingleValue<Data, AllocatesMemoryInArena>>({type_}),
|
||||
AggregateFunctionsSingleValue(const DataTypes& arguments)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionsSingleValue<Data>>(arguments),
|
||||
type(this->argument_types[0]) {
|
||||
if (StringRef(Data::name()) == StringRef("min") ||
|
||||
StringRef(Data::name()) == StringRef("max")) {
|
||||
@ -535,8 +532,6 @@ public:
|
||||
this->data(place).read(buf);
|
||||
}
|
||||
|
||||
bool allocates_memory_in_arena() const override { return AllocatesMemoryInArena; }
|
||||
|
||||
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
|
||||
this->data(place).insert_result_into(to);
|
||||
}
|
||||
|
||||
@ -310,10 +310,7 @@ public:
|
||||
|
||||
if (has_null) {
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
if (!column->is_null_at(i)) {
|
||||
this->set_flag(place);
|
||||
this->add(place, columns, i, arena);
|
||||
}
|
||||
this->add(place, columns, i, arena);
|
||||
}
|
||||
} else {
|
||||
this->set_flag(place);
|
||||
@ -329,10 +326,7 @@ public:
|
||||
|
||||
if (has_null) {
|
||||
for (size_t i = batch_begin; i <= batch_end; ++i) {
|
||||
if (!column->is_null_at(i)) {
|
||||
this->set_flag(place);
|
||||
this->add(place, columns, i, arena);
|
||||
}
|
||||
this->add(place, columns, i, arena);
|
||||
}
|
||||
} else {
|
||||
this->set_flag(place);
|
||||
@ -587,14 +581,6 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void add_not_nullable(AggregateDataPtr __restrict place, const IColumn** columns,
|
||||
size_t row_num, Arena* arena) const {
|
||||
const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
|
||||
this->set_flag(place);
|
||||
const IColumn* nested_column = &column->get_nested_column();
|
||||
this->nested_function->add(this->nested_place(place), &nested_column, row_num, arena);
|
||||
}
|
||||
|
||||
void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
|
||||
const IColumn** columns, Arena* arena, bool agg_many) const override {
|
||||
const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
|
||||
@ -617,10 +603,7 @@ public:
|
||||
|
||||
if (has_null) {
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
if (!column->is_null_at(i)) {
|
||||
this->set_flag(place);
|
||||
this->add(place, columns, i, arena);
|
||||
}
|
||||
this->add(place, columns, i, arena);
|
||||
}
|
||||
} else {
|
||||
this->set_flag(place);
|
||||
@ -636,10 +619,7 @@ public:
|
||||
|
||||
if (has_null) {
|
||||
for (size_t i = batch_begin; i <= batch_end; ++i) {
|
||||
if (!column->is_null_at(i)) {
|
||||
this->set_flag(place);
|
||||
this->add(place, columns, i, arena);
|
||||
}
|
||||
this->add(place, columns, i, arena);
|
||||
}
|
||||
} else {
|
||||
this->set_flag(place);
|
||||
@ -650,4 +630,70 @@ public:
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename NestFuction, bool result_is_nullable>
|
||||
class AggregateFunctionNullVariadicInline final
|
||||
: public AggregateFunctionNullBaseInline<
|
||||
NestFuction, result_is_nullable,
|
||||
AggregateFunctionNullVariadicInline<NestFuction, result_is_nullable>> {
|
||||
public:
|
||||
AggregateFunctionNullVariadicInline(IAggregateFunction* nested_function_,
|
||||
const DataTypes& arguments)
|
||||
: AggregateFunctionNullBaseInline<
|
||||
NestFuction, result_is_nullable,
|
||||
AggregateFunctionNullVariadicInline<NestFuction, result_is_nullable>>(
|
||||
nested_function_, arguments),
|
||||
number_of_arguments(arguments.size()) {
|
||||
if (number_of_arguments == 1) {
|
||||
LOG(FATAL)
|
||||
<< "Logical error: single argument is passed to AggregateFunctionNullVariadic";
|
||||
}
|
||||
|
||||
if (number_of_arguments > MAX_ARGS) {
|
||||
LOG(FATAL) << fmt::format(
|
||||
"Maximum number of arguments for aggregate function with Nullable types is {}",
|
||||
size_t(MAX_ARGS));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < number_of_arguments; ++i) {
|
||||
is_nullable[i] = arguments[i]->is_nullable();
|
||||
}
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
|
||||
Arena* arena) const override {
|
||||
/// This container stores the columns we really pass to the nested function.
|
||||
const IColumn* nested_columns[number_of_arguments];
|
||||
|
||||
for (size_t i = 0; i < number_of_arguments; ++i) {
|
||||
if (is_nullable[i]) {
|
||||
const ColumnNullable& nullable_col =
|
||||
assert_cast<const ColumnNullable&>(*columns[i]);
|
||||
if (nullable_col.is_null_at(row_num)) {
|
||||
/// If at least one column has a null value in the current row,
|
||||
/// we don't process this row.
|
||||
return;
|
||||
}
|
||||
nested_columns[i] = &nullable_col.get_nested_column();
|
||||
} else {
|
||||
nested_columns[i] = columns[i];
|
||||
}
|
||||
}
|
||||
|
||||
this->set_flag(place);
|
||||
this->nested_function->add(this->nested_place(place), nested_columns, row_num, arena);
|
||||
}
|
||||
|
||||
bool allocates_memory_in_arena() const override {
|
||||
return this->nested_function->allocates_memory_in_arena();
|
||||
}
|
||||
|
||||
private:
|
||||
// The array length is fixed in the implementation of some aggregate functions.
|
||||
// Therefore we choose 256 as the appropriate maximum length limit.
|
||||
static const size_t MAX_ARGS = 256;
|
||||
size_t number_of_arguments = 0;
|
||||
std::array<char, MAX_ARGS>
|
||||
is_nullable; /// Plain array is better than std::vector due to one indirection less.
|
||||
};
|
||||
} // namespace doris::vectorized
|
||||
|
||||
@ -36,18 +36,22 @@ AggregateFunctionPtr create_aggregate_function_orthogonal(const std::string& nam
|
||||
} else if (argument_types.size() == 1) {
|
||||
return std::make_shared<AggFunctionOrthBitmapFunc<Impl<StringRef>>>(argument_types);
|
||||
} else {
|
||||
const IDataType& argument_type = *argument_types[1];
|
||||
AggregateFunctionPtr res(create_with_numeric_type<AggFunctionOrthBitmapFunc, Impl>(
|
||||
argument_type, argument_types));
|
||||
|
||||
WhichDataType which(argument_type);
|
||||
WhichDataType which(*remove_nullable(argument_types[1]));
|
||||
|
||||
AggregateFunctionPtr res(
|
||||
creator_with_type_base<true, true, false, 1>::create<AggFunctionOrthBitmapFunc,
|
||||
Impl>(result_is_nullable,
|
||||
argument_types));
|
||||
if (res) {
|
||||
return res;
|
||||
} else if (which.is_string_or_fixed_string()) {
|
||||
return std::make_shared<AggFunctionOrthBitmapFunc<Impl<std::string_view>>>(
|
||||
argument_types);
|
||||
res.reset(
|
||||
creator_without_type::create<AggFunctionOrthBitmapFunc<Impl<std::string_view>>>(
|
||||
result_is_nullable, argument_types));
|
||||
return res;
|
||||
}
|
||||
|
||||
const IDataType& argument_type = *argument_types[1];
|
||||
LOG(WARNING) << "Incorrect Type " << argument_type.get_name()
|
||||
<< " of arguments for aggregate function " << name;
|
||||
return nullptr;
|
||||
@ -91,5 +95,16 @@ void register_aggregate_function_orthogonal_bitmap(AggregateFunctionSimpleFactor
|
||||
create_aggregate_function_orthogonal_bitmap_union_count);
|
||||
|
||||
factory.register_function("intersect_count", create_aggregate_function_intersect_count);
|
||||
|
||||
factory.register_function("orthogonal_bitmap_intersect",
|
||||
create_aggregate_function_orthogonal_bitmap_intersect, true);
|
||||
|
||||
factory.register_function("orthogonal_bitmap_intersect_count",
|
||||
create_aggregate_function_orthogonal_bitmap_intersect_count, true);
|
||||
|
||||
factory.register_function("orthogonal_bitmap_union_count",
|
||||
create_aggregate_function_orthogonal_bitmap_union_count, true);
|
||||
|
||||
factory.register_function("intersect_count", create_aggregate_function_intersect_count, true);
|
||||
}
|
||||
} // namespace doris::vectorized
|
||||
@ -92,12 +92,7 @@ public:
|
||||
AggregateFunctionProduct(const DataTypes& argument_types_)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionProduct<T, TResult, Data>>(
|
||||
argument_types_),
|
||||
scale(0) {}
|
||||
|
||||
AggregateFunctionProduct(const IDataType& data_type, const DataTypes& argument_types_)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionProduct<T, TResult, Data>>(
|
||||
argument_types_),
|
||||
scale(get_decimal_scale(data_type)) {}
|
||||
scale(get_decimal_scale(*argument_types_[0])) {}
|
||||
|
||||
DataTypePtr get_return_type() const override {
|
||||
if constexpr (IsDecimalNumber<T>) {
|
||||
|
||||
@ -45,25 +45,8 @@ template <template <typename> class Function>
|
||||
AggregateFunctionPtr create_aggregate_function_sum(const std::string& name,
|
||||
const DataTypes& argument_types,
|
||||
const bool result_is_nullable) {
|
||||
AggregateFunctionPtr res;
|
||||
DataTypePtr data_type = argument_types[0];
|
||||
if (data_type->is_nullable()) {
|
||||
auto no_null_argument_types = remove_nullable(argument_types);
|
||||
if (is_decimal(no_null_argument_types[0])) {
|
||||
res.reset(create_with_decimal_type_null<Function>(
|
||||
no_null_argument_types, *no_null_argument_types[0], no_null_argument_types));
|
||||
} else {
|
||||
res.reset(create_with_numeric_type_null<Function>(no_null_argument_types,
|
||||
no_null_argument_types));
|
||||
}
|
||||
} else {
|
||||
if (is_decimal(data_type)) {
|
||||
res.reset(create_with_decimal_type<Function>(*data_type, *data_type, argument_types));
|
||||
} else {
|
||||
res.reset(create_with_numeric_type<Function>(*data_type, argument_types));
|
||||
}
|
||||
}
|
||||
|
||||
AggregateFunctionPtr res(
|
||||
creator_with_type::create<Function>(result_is_nullable, argument_types));
|
||||
if (!res) {
|
||||
LOG(WARNING) << fmt::format("Illegal type {} of argument for aggregate function {}",
|
||||
argument_types[0]->get_name(), name);
|
||||
|
||||
@ -59,12 +59,7 @@ public:
|
||||
AggregateFunctionSum(const DataTypes& argument_types_)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data>>(
|
||||
argument_types_),
|
||||
scale(0) {}
|
||||
|
||||
AggregateFunctionSum(const IDataType& data_type, const DataTypes& argument_types_)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data>>(
|
||||
argument_types_),
|
||||
scale(get_decimal_scale(data_type)) {}
|
||||
scale(get_decimal_scale(*argument_types_[0])) {}
|
||||
|
||||
DataTypePtr get_return_type() const override {
|
||||
if constexpr (IsDecimalNumber<T>) {
|
||||
|
||||
@ -28,7 +28,7 @@
|
||||
|
||||
namespace doris::vectorized {
|
||||
|
||||
template <template <typename> class Data, typename DataForVariadic>
|
||||
template <template <typename> class Data>
|
||||
AggregateFunctionPtr create_aggregate_function_uniq(const std::string& name,
|
||||
const DataTypes& argument_types,
|
||||
const bool result_is_nullable) {
|
||||
@ -38,27 +38,29 @@ AggregateFunctionPtr create_aggregate_function_uniq(const std::string& name,
|
||||
}
|
||||
|
||||
if (argument_types.size() == 1) {
|
||||
const IDataType& argument_type = *argument_types[0];
|
||||
|
||||
AggregateFunctionPtr res(create_with_numeric_type<AggregateFunctionUniq, Data>(
|
||||
*argument_types[0], argument_types));
|
||||
|
||||
const IDataType& argument_type = *remove_nullable(argument_types[0]);
|
||||
WhichDataType which(argument_type);
|
||||
// TODO: DateType
|
||||
|
||||
AggregateFunctionPtr res(creator_with_numeric_type::create<AggregateFunctionUniq, Data>(
|
||||
result_is_nullable, argument_types));
|
||||
if (res) {
|
||||
return res;
|
||||
} else if (which.is_decimal32()) {
|
||||
return std::make_shared<AggregateFunctionUniq<Decimal32, Data<Int32>>>(argument_types);
|
||||
return AggregateFunctionPtr(
|
||||
creator_without_type::create<AggregateFunctionUniq<Decimal32, Data<Int32>>>(
|
||||
result_is_nullable, argument_types));
|
||||
} else if (which.is_decimal64()) {
|
||||
return std::make_shared<AggregateFunctionUniq<Decimal64, Data<Int64>>>(argument_types);
|
||||
} else if (which.is_decimal128()) {
|
||||
return std::make_shared<AggregateFunctionUniq<Decimal128, Data<Int128>>>(
|
||||
argument_types);
|
||||
} else if (which.is_decimal128i()) {
|
||||
return std::make_shared<AggregateFunctionUniq<Decimal128, Data<Int128>>>(
|
||||
argument_types);
|
||||
return AggregateFunctionPtr(
|
||||
creator_without_type::create<AggregateFunctionUniq<Decimal64, Data<Int64>>>(
|
||||
result_is_nullable, argument_types));
|
||||
} else if (which.is_decimal128() || which.is_decimal128i()) {
|
||||
return AggregateFunctionPtr(
|
||||
creator_without_type::create<AggregateFunctionUniq<Decimal128, Data<Int128>>>(
|
||||
result_is_nullable, argument_types));
|
||||
} else if (which.is_string_or_fixed_string()) {
|
||||
return std::make_shared<AggregateFunctionUniq<String, Data<String>>>(argument_types);
|
||||
return AggregateFunctionPtr(
|
||||
creator_without_type::create<AggregateFunctionUniq<String, Data<String>>>(
|
||||
result_is_nullable, argument_types));
|
||||
}
|
||||
}
|
||||
|
||||
@ -67,9 +69,9 @@ AggregateFunctionPtr create_aggregate_function_uniq(const std::string& name,
|
||||
|
||||
void register_aggregate_function_uniq(AggregateFunctionSimpleFactory& factory) {
|
||||
AggregateFunctionCreator creator =
|
||||
create_aggregate_function_uniq<AggregateFunctionUniqExactData,
|
||||
AggregateFunctionUniqExactData<String>>;
|
||||
create_aggregate_function_uniq<AggregateFunctionUniqExactData>;
|
||||
factory.register_function("multi_distinct_count", creator);
|
||||
factory.register_function("multi_distinct_count", creator, true);
|
||||
}
|
||||
|
||||
} // namespace doris::vectorized
|
||||
|
||||
@ -23,9 +23,10 @@
|
||||
#include "vec/aggregate_functions/aggregate_function.h"
|
||||
#include "vec/aggregate_functions/aggregate_function_null.h"
|
||||
#include "vec/data_types/data_type.h"
|
||||
#include "vec/utils/template_helpers.hpp"
|
||||
|
||||
// TODO: Should we support decimal in numeric types?
|
||||
#define FOR_NUMERIC_TYPES(M) \
|
||||
#define FOR_INTEGER_TYPES(M) \
|
||||
M(UInt8) \
|
||||
M(UInt16) \
|
||||
M(UInt32) \
|
||||
@ -34,10 +35,16 @@
|
||||
M(Int16) \
|
||||
M(Int32) \
|
||||
M(Int64) \
|
||||
M(Int128) \
|
||||
M(Float32) \
|
||||
M(Int128)
|
||||
|
||||
#define FOR_FLOAT_TYPES(M) \
|
||||
M(Float32) \
|
||||
M(Float64)
|
||||
|
||||
#define FOR_NUMERIC_TYPES(M) \
|
||||
FOR_INTEGER_TYPES(M) \
|
||||
FOR_FLOAT_TYPES(M)
|
||||
|
||||
#define FOR_DECIMAL_TYPES(M) \
|
||||
M(Decimal32) \
|
||||
M(Decimal64) \
|
||||
@ -48,150 +55,162 @@ namespace doris::vectorized {
|
||||
|
||||
/** Create an aggregate function with a numeric type in the template parameter, depending on the type of the argument.
|
||||
*/
|
||||
template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
|
||||
static IAggregateFunction* create_with_numeric_type(const IDataType& argument_type,
|
||||
TArgs&&... args) {
|
||||
WhichDataType which(argument_type);
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) \
|
||||
return new AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...);
|
||||
FOR_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
|
||||
static IAggregateFunction* create_with_numeric_type_null(const DataTypes& argument_types,
|
||||
TArgs&&... args) {
|
||||
WhichDataType which(argument_types[0]);
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) \
|
||||
return new AggregateFunctionNullUnaryInline<AggregateFunctionTemplate<TYPE>, true>( \
|
||||
new AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...), \
|
||||
argument_types);
|
||||
FOR_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <template <typename, bool> class AggregateFunctionTemplate, bool bool_param,
|
||||
typename... TArgs>
|
||||
static IAggregateFunction* create_with_numeric_type(const IDataType& argument_type,
|
||||
TArgs&&... args) {
|
||||
WhichDataType which(argument_type);
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) \
|
||||
return new AggregateFunctionTemplate<TYPE, bool_param>(std::forward<TArgs>(args)...);
|
||||
FOR_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <template <typename, typename> class AggregateFunctionTemplate, typename Data,
|
||||
typename... TArgs>
|
||||
static IAggregateFunction* create_with_numeric_type(const IDataType& argument_type,
|
||||
TArgs&&... args) {
|
||||
WhichDataType which(argument_type);
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) \
|
||||
return new AggregateFunctionTemplate<TYPE, Data>(std::forward<TArgs>(args)...);
|
||||
FOR_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <template <typename, typename> class AggregateFunctionTemplate,
|
||||
template <typename> class Data, typename... TArgs>
|
||||
static IAggregateFunction* create_with_numeric_type(const IDataType& argument_type,
|
||||
TArgs&&... args) {
|
||||
WhichDataType which(argument_type);
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) \
|
||||
return new AggregateFunctionTemplate<TYPE, Data<TYPE>>(std::forward<TArgs>(args)...);
|
||||
FOR_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <template <typename> class AggregateFunctionTemplate, typename Type>
|
||||
struct BuilderDirect {
|
||||
using T = AggregateFunctionTemplate<Type>;
|
||||
};
|
||||
template <template <typename> class AggregateFunctionTemplate, template <typename> class Data,
|
||||
typename... TArgs>
|
||||
static IAggregateFunction* create_with_numeric_type(const IDataType& argument_type,
|
||||
TArgs&&... args) {
|
||||
WhichDataType which(argument_type);
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) \
|
||||
return new AggregateFunctionTemplate<Data<TYPE>>(std::forward<TArgs>(args)...);
|
||||
FOR_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
return nullptr;
|
||||
}
|
||||
typename Type>
|
||||
struct BuilderData {
|
||||
using T = AggregateFunctionTemplate<Data<Type>>;
|
||||
};
|
||||
template <template <typename> class AggregateFunctionTemplate, template <typename> class Data,
|
||||
template <typename> class Impl, typename Type>
|
||||
struct BuilderDataImpl {
|
||||
using T = AggregateFunctionTemplate<Data<Impl<Type>>>;
|
||||
};
|
||||
template <template <typename, typename> class AggregateFunctionTemplate,
|
||||
template <typename> class Data, typename Type>
|
||||
struct BuilderDirectAndData {
|
||||
using T = AggregateFunctionTemplate<Type, Data<Type>>;
|
||||
};
|
||||
|
||||
template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
|
||||
static IAggregateFunction* create_with_decimal_type(const IDataType& argument_type,
|
||||
TArgs&&... args) {
|
||||
WhichDataType which(argument_type);
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) \
|
||||
return new AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...);
|
||||
FOR_DECIMAL_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
return nullptr;
|
||||
}
|
||||
template <template <typename> class AggregateFunctionTemplate>
|
||||
struct CurryDirect {
|
||||
template <typename Type>
|
||||
using Builder = BuilderDirect<AggregateFunctionTemplate, Type>;
|
||||
};
|
||||
template <template <typename> class AggregateFunctionTemplate, template <typename> class Data>
|
||||
struct CurryData {
|
||||
template <typename Type>
|
||||
using Builder = BuilderData<AggregateFunctionTemplate, Data, Type>;
|
||||
};
|
||||
template <template <typename> class AggregateFunctionTemplate, template <typename> class Data,
|
||||
template <typename> class Impl>
|
||||
struct CurryDataImpl {
|
||||
template <typename Type>
|
||||
using Builder = BuilderDataImpl<AggregateFunctionTemplate, Data, Impl, Type>;
|
||||
};
|
||||
template <template <typename, typename> class AggregateFunctionTemplate,
|
||||
template <typename> class Data>
|
||||
struct CurryDirectAndData {
|
||||
template <typename Type>
|
||||
using Builder = BuilderDirectAndData<AggregateFunctionTemplate, Data, Type>;
|
||||
};
|
||||
|
||||
template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
|
||||
static IAggregateFunction* create_with_decimal_type_null(const DataTypes& argument_types,
|
||||
TArgs&&... args) {
|
||||
WhichDataType which(argument_types[0]);
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) \
|
||||
return new AggregateFunctionNullUnaryInline<AggregateFunctionTemplate<TYPE>, true>( \
|
||||
new AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...), \
|
||||
argument_types);
|
||||
FOR_DECIMAL_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
return nullptr;
|
||||
}
|
||||
template <bool allow_integer, bool allow_float, bool allow_decimal, int define_index = 0>
|
||||
struct creator_with_type_base {
|
||||
template <typename Class, typename... TArgs>
|
||||
static IAggregateFunction* create_base(const bool result_is_nullable,
|
||||
const DataTypes& argument_types, TArgs&&... args) {
|
||||
WhichDataType which(remove_nullable(argument_types[define_index]));
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) { \
|
||||
using T = typename Class::template Builder<TYPE>::T; \
|
||||
if (have_nullable(argument_types)) { \
|
||||
IAggregateFunction* result = nullptr; \
|
||||
if (argument_types.size() > 1) { \
|
||||
std::visit( \
|
||||
[&](auto result_is_nullable) { \
|
||||
result = new AggregateFunctionNullVariadicInline<T, \
|
||||
result_is_nullable>( \
|
||||
new T(std::forward<TArgs>(args)..., \
|
||||
remove_nullable(argument_types)), \
|
||||
argument_types); \
|
||||
}, \
|
||||
make_bool_variant(result_is_nullable)); \
|
||||
} else { \
|
||||
std::visit( \
|
||||
[&](auto result_is_nullable) { \
|
||||
result = new AggregateFunctionNullUnaryInline<T, result_is_nullable>( \
|
||||
new T(std::forward<TArgs>(args)..., \
|
||||
remove_nullable(argument_types)), \
|
||||
argument_types); \
|
||||
}, \
|
||||
make_bool_variant(result_is_nullable)); \
|
||||
} \
|
||||
return result; \
|
||||
} else { \
|
||||
return new T(std::forward<TArgs>(args)..., argument_types); \
|
||||
} \
|
||||
}
|
||||
|
||||
template <template <typename, typename> class AggregateFunctionTemplate, typename Data,
|
||||
typename... TArgs>
|
||||
static IAggregateFunction* create_with_decimal_type(const IDataType& argument_type,
|
||||
TArgs&&... args) {
|
||||
WhichDataType which(argument_type);
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) \
|
||||
return new AggregateFunctionTemplate<TYPE, Data>(std::forward<TArgs>(args)...);
|
||||
FOR_DECIMAL_TYPES(DISPATCH)
|
||||
if constexpr (allow_integer) {
|
||||
FOR_INTEGER_TYPES(DISPATCH);
|
||||
}
|
||||
if constexpr (allow_float) {
|
||||
FOR_FLOAT_TYPES(DISPATCH);
|
||||
}
|
||||
if constexpr (allow_decimal) {
|
||||
FOR_DECIMAL_TYPES(DISPATCH);
|
||||
}
|
||||
#undef DISPATCH
|
||||
return nullptr;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/** For template with two arguments.
|
||||
*/
|
||||
template <typename FirstType, template <typename, typename> class AggregateFunctionTemplate,
|
||||
typename... TArgs>
|
||||
static IAggregateFunction* create_with_two_numeric_types_second(const IDataType& second_type,
|
||||
TArgs&&... args) {
|
||||
WhichDataType which(second_type);
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) \
|
||||
return new AggregateFunctionTemplate<FirstType, TYPE>(std::forward<TArgs>(args)...);
|
||||
FOR_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
return nullptr;
|
||||
}
|
||||
template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
|
||||
static IAggregateFunction* create(TArgs&&... args) {
|
||||
return create_base<CurryDirect<AggregateFunctionTemplate>>(std::forward<TArgs>(args)...);
|
||||
}
|
||||
|
||||
template <template <typename, typename> class AggregateFunctionTemplate, typename... TArgs>
|
||||
static IAggregateFunction* create_with_two_numeric_types(const IDataType& first_type,
|
||||
const IDataType& second_type,
|
||||
TArgs&&... args) {
|
||||
WhichDataType which(first_type);
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) \
|
||||
return create_with_two_numeric_types_second<TYPE, AggregateFunctionTemplate>( \
|
||||
second_type, std::forward<TArgs>(args)...);
|
||||
FOR_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
return nullptr;
|
||||
}
|
||||
template <template <typename> class AggregateFunctionTemplate, template <typename> class Data,
|
||||
typename... TArgs>
|
||||
static IAggregateFunction* create(TArgs&&... args) {
|
||||
return create_base<CurryData<AggregateFunctionTemplate, Data>>(
|
||||
std::forward<TArgs>(args)...);
|
||||
}
|
||||
|
||||
template <template <typename> class AggregateFunctionTemplate, template <typename> class Data,
|
||||
template <typename> class Impl, typename... TArgs>
|
||||
static IAggregateFunction* create(TArgs&&... args) {
|
||||
return create_base<CurryDataImpl<AggregateFunctionTemplate, Data, Impl>>(
|
||||
std::forward<TArgs>(args)...);
|
||||
}
|
||||
|
||||
template <template <typename, typename> class AggregateFunctionTemplate,
|
||||
template <typename> class Data, typename... TArgs>
|
||||
static IAggregateFunction* create(TArgs&&... args) {
|
||||
return create_base<CurryDirectAndData<AggregateFunctionTemplate, Data>>(
|
||||
std::forward<TArgs>(args)...);
|
||||
}
|
||||
};
|
||||
|
||||
using creator_with_integer_type = creator_with_type_base<true, false, false>;
|
||||
using creator_with_numeric_type = creator_with_type_base<true, true, false>;
|
||||
using creator_with_decimal_type = creator_with_type_base<false, false, true>;
|
||||
using creator_with_type = creator_with_type_base<true, true, true>;
|
||||
|
||||
struct creator_without_type {
|
||||
template <typename AggregateFunctionTemplate, typename... TArgs>
|
||||
static IAggregateFunction* create(const bool result_is_nullable,
|
||||
const DataTypes& argument_types, TArgs&&... args) {
|
||||
if (have_nullable(argument_types)) {
|
||||
IAggregateFunction* result = nullptr;
|
||||
if (argument_types.size() > 1) {
|
||||
std::visit(
|
||||
[&](auto result_is_nullable) {
|
||||
result = new AggregateFunctionNullVariadicInline<
|
||||
AggregateFunctionTemplate, result_is_nullable>(
|
||||
new AggregateFunctionTemplate(std::forward<TArgs>(args)...,
|
||||
remove_nullable(argument_types)),
|
||||
argument_types);
|
||||
},
|
||||
make_bool_variant(result_is_nullable));
|
||||
} else {
|
||||
std::visit(
|
||||
[&](auto result_is_nullable) {
|
||||
result = new AggregateFunctionNullUnaryInline<AggregateFunctionTemplate,
|
||||
result_is_nullable>(
|
||||
new AggregateFunctionTemplate(std::forward<TArgs>(args)...,
|
||||
remove_nullable(argument_types)),
|
||||
argument_types);
|
||||
},
|
||||
make_bool_variant(result_is_nullable));
|
||||
}
|
||||
return result;
|
||||
} else {
|
||||
return new AggregateFunctionTemplate(std::forward<TArgs>(args)..., argument_types);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace doris::vectorized
|
||||
|
||||
@ -255,8 +255,7 @@ inline const DataTypeDecimal<T>* check_decimal(const IDataType& data_type) {
|
||||
return typeid_cast<const DataTypeDecimal<T>*>(&data_type);
|
||||
}
|
||||
|
||||
inline UInt32 get_decimal_scale(const IDataType& data_type,
|
||||
UInt32 default_value = std::numeric_limits<UInt32>::max()) {
|
||||
inline UInt32 get_decimal_scale(const IDataType& data_type, UInt32 default_value = 0) {
|
||||
if (auto* decimal_type = check_decimal<Decimal32>(data_type)) return decimal_type->get_scale();
|
||||
if (auto* decimal_type = check_decimal<Decimal64>(data_type)) return decimal_type->get_scale();
|
||||
if (auto* decimal_type = check_decimal<Decimal128>(data_type)) return decimal_type->get_scale();
|
||||
|
||||
@ -161,13 +161,18 @@ DataTypePtr remove_nullable(const DataTypePtr& type) {
|
||||
DataTypes remove_nullable(const DataTypes& types) {
|
||||
DataTypes no_null_types;
|
||||
for (auto& type : types) {
|
||||
if (type->is_nullable()) {
|
||||
no_null_types.push_back(static_cast<const DataTypeNullable&>(*type).get_nested_type());
|
||||
} else {
|
||||
no_null_types.push_back(type);
|
||||
}
|
||||
no_null_types.push_back(remove_nullable(type));
|
||||
}
|
||||
return no_null_types;
|
||||
}
|
||||
|
||||
bool have_nullable(const DataTypes& types) {
|
||||
for (auto& type : types) {
|
||||
if (type->is_nullable()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace doris::vectorized
|
||||
|
||||
@ -49,6 +49,10 @@ public:
|
||||
|
||||
bool equals(const IDataType& rhs) const override;
|
||||
|
||||
bool is_value_unambiguously_represented_in_contiguous_memory_region() const override {
|
||||
return nested_data_type->is_value_unambiguously_represented_in_contiguous_memory_region();
|
||||
}
|
||||
|
||||
bool get_is_parametric() const override { return true; }
|
||||
bool have_subtypes() const override { return true; }
|
||||
bool cannot_be_stored_in_tables() const override {
|
||||
@ -94,5 +98,6 @@ private:
|
||||
DataTypePtr make_nullable(const DataTypePtr& type);
|
||||
DataTypePtr remove_nullable(const DataTypePtr& type);
|
||||
DataTypes remove_nullable(const DataTypes& types);
|
||||
bool have_nullable(const DataTypes& types);
|
||||
|
||||
} // namespace doris::vectorized
|
||||
|
||||
@ -335,7 +335,7 @@ Status get_least_supertype(const DataTypes& types, DataTypePtr* type, bool compa
|
||||
|
||||
UInt32 max_scale = 0;
|
||||
for (const auto& type : types) {
|
||||
UInt32 scale = get_decimal_scale(*type, 0);
|
||||
UInt32 scale = get_decimal_scale(*type);
|
||||
if (scale > max_scale) max_scale = scale;
|
||||
}
|
||||
|
||||
|
||||
@ -118,15 +118,8 @@ struct AggregateFunction {
|
||||
|
||||
static auto create(const DataTypePtr& data_type_ptr) -> AggregateFunctionPtr {
|
||||
DataTypes data_types = {remove_nullable(data_type_ptr)};
|
||||
auto& data_type = *data_types.front();
|
||||
AggregateFunctionPtr nested_function;
|
||||
if (is_decimal(data_types.front())) {
|
||||
nested_function = AggregateFunctionPtr(
|
||||
create_with_decimal_type<Function>(data_type, data_type, data_types));
|
||||
} else {
|
||||
nested_function =
|
||||
AggregateFunctionPtr(create_with_numeric_type<Function>(data_type, data_types));
|
||||
}
|
||||
nested_function.reset(creator_with_type::create<Function>(false, data_types));
|
||||
|
||||
AggregateFunctionPtr function;
|
||||
function.reset(new AggregateFunctionNullUnary<true>(nested_function,
|
||||
|
||||
Reference in New Issue
Block a user