[Feature](agg_state) support agg_state combinators (#19969)
support agg_state combinators state/merge/union
This commit is contained in:
@ -233,7 +233,7 @@ void MemTable::_aggregate_two_row_in_block(vectorized::MutableBlock& mutable_blo
|
||||
auto col_ptr = mutable_block.mutable_columns()[cid].get();
|
||||
_agg_functions[cid]->add(row_in_skiplist->agg_places(cid),
|
||||
const_cast<const doris::vectorized::IColumn**>(&col_ptr),
|
||||
new_row->_row_pos, nullptr);
|
||||
new_row->_row_pos, _arena.get());
|
||||
}
|
||||
}
|
||||
void MemTable::_put_into_output(vectorized::Block& in_block) {
|
||||
@ -298,7 +298,7 @@ void MemTable::_finalize_one_row(RowInBlock* row,
|
||||
} else {
|
||||
function->reset(agg_place);
|
||||
function->add(agg_place, const_cast<const doris::vectorized::IColumn**>(&col_ptr),
|
||||
row_pos, nullptr);
|
||||
row_pos, _arena.get());
|
||||
}
|
||||
}
|
||||
if constexpr (is_final) {
|
||||
@ -343,7 +343,7 @@ void MemTable::_aggregate() {
|
||||
_agg_functions[cid]->create(data);
|
||||
_agg_functions[cid]->add(
|
||||
data, const_cast<const doris::vectorized::IColumn**>(&col_ptr),
|
||||
prev_row->_row_pos, nullptr);
|
||||
prev_row->_row_pos, _arena.get());
|
||||
}
|
||||
}
|
||||
_stat.merged_rows++;
|
||||
|
||||
@ -39,6 +39,7 @@
|
||||
#include "runtime/thread_context.h"
|
||||
#include "tablet_meta.h"
|
||||
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
|
||||
#include "vec/aggregate_functions/aggregate_function_state_union.h"
|
||||
#include "vec/core/block.h"
|
||||
#include "vec/data_types/data_type.h"
|
||||
#include "vec/data_types/data_type_factory.hpp"
|
||||
@ -491,14 +492,21 @@ bool TabletColumn::is_row_store_column() const {
|
||||
return _col_name == BeConsts::ROW_STORE_COL;
|
||||
}
|
||||
|
||||
vectorized::AggregateFunctionPtr TabletColumn::get_aggregate_function_merge() const {
|
||||
vectorized::AggregateFunctionPtr TabletColumn::get_aggregate_function_union(
|
||||
vectorized::DataTypePtr type) const {
|
||||
auto state_type = dynamic_cast<const vectorized::DataTypeAggState*>(type.get());
|
||||
if (!state_type) {
|
||||
return nullptr;
|
||||
}
|
||||
vectorized::DataTypes argument_types;
|
||||
for (auto col : _sub_columns) {
|
||||
argument_types.push_back(vectorized::DataTypeFactory::instance().create_data_type(col));
|
||||
auto sub_type = vectorized::DataTypeFactory::instance().create_data_type(col);
|
||||
state_type->add_sub_type(sub_type);
|
||||
}
|
||||
auto function = vectorized::AggregateFunctionSimpleFactory::instance().get(
|
||||
_aggregation_name, argument_types, false);
|
||||
return function;
|
||||
auto agg_function = vectorized::AggregateFunctionSimpleFactory::instance().get(
|
||||
_aggregation_name, state_type->get_sub_types(), false);
|
||||
|
||||
return vectorized::AggregateStateUnion::create(agg_function, {type}, type);
|
||||
}
|
||||
|
||||
vectorized::AggregateFunctionPtr TabletColumn::get_aggregate_function(std::string suffix) const {
|
||||
@ -514,7 +522,7 @@ vectorized::AggregateFunctionPtr TabletColumn::get_aggregate_function(std::strin
|
||||
if (function) {
|
||||
return function;
|
||||
}
|
||||
return get_aggregate_function_merge();
|
||||
return get_aggregate_function_union(type);
|
||||
}
|
||||
|
||||
void TabletIndex::init_from_thrift(const TOlapTableIndex& index,
|
||||
|
||||
@ -89,7 +89,8 @@ public:
|
||||
void set_is_nullable(bool is_nullable) { _is_nullable = is_nullable; }
|
||||
void set_has_default_value(bool has) { _has_default_value = has; }
|
||||
FieldAggregationMethod aggregation() const { return _aggregation; }
|
||||
vectorized::AggregateFunctionPtr get_aggregate_function_merge() const;
|
||||
vectorized::AggregateFunctionPtr get_aggregate_function_union(
|
||||
vectorized::DataTypePtr type) const;
|
||||
vectorized::AggregateFunctionPtr get_aggregate_function(std::string suffix) const;
|
||||
int precision() const { return _precision; }
|
||||
int frac() const { return _frac; }
|
||||
|
||||
@ -59,7 +59,6 @@ struct TypeDescriptor {
|
||||
/// The maximum precision representable by a 8-byte decimal (Decimal8Value)
|
||||
static constexpr int MAX_DECIMAL8_PRECISION = 18;
|
||||
|
||||
// Empty for scalar types
|
||||
std::vector<TypeDescriptor> children;
|
||||
|
||||
// Only set if type == TYPE_STRUCT. The field name of each child.
|
||||
@ -148,6 +147,13 @@ struct TypeDescriptor {
|
||||
int idx = 0;
|
||||
TypeDescriptor result(t.types, &idx);
|
||||
DCHECK_EQ(idx, t.types.size() - 1);
|
||||
if (result.type == TYPE_AGG_STATE) {
|
||||
DCHECK(t.__isset.sub_types);
|
||||
for (auto sub : t.sub_types) {
|
||||
result.children.push_back(from_thrift(sub));
|
||||
result.contains_nulls.push_back(sub.is_nullable);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
@ -134,7 +134,7 @@ public:
|
||||
MutableColumnPtr& dst, const size_t num_rows) const = 0;
|
||||
|
||||
virtual void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
|
||||
MutableColumnPtr& dst) const = 0;
|
||||
IColumn& to) const = 0;
|
||||
|
||||
/// Deserializes state. This function is called only for empty (just created) states.
|
||||
virtual void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
|
||||
@ -150,6 +150,10 @@ public:
|
||||
virtual void deserialize_and_merge(AggregateDataPtr __restrict place, BufferReadable& buf,
|
||||
Arena* arena) const = 0;
|
||||
|
||||
virtual void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict place,
|
||||
const IColumn& column, size_t begin,
|
||||
size_t end, Arena* arena) const = 0;
|
||||
|
||||
virtual void deserialize_and_merge_from_column(AggregateDataPtr __restrict place,
|
||||
const IColumn& column, Arena* arena) const = 0;
|
||||
|
||||
@ -332,8 +336,8 @@ public:
|
||||
}
|
||||
|
||||
void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
|
||||
MutableColumnPtr& dst) const override {
|
||||
VectorBufferWriter writter(assert_cast<ColumnString&>(*dst));
|
||||
IColumn& to) const override {
|
||||
VectorBufferWriter writter(assert_cast<ColumnString&>(to));
|
||||
assert_cast<const Derived*>(this)->serialize(place, writter);
|
||||
writter.commit();
|
||||
}
|
||||
@ -382,6 +386,38 @@ public:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict place,
|
||||
const IColumn& column, size_t begin, size_t end,
|
||||
Arena* arena) const override {
|
||||
DCHECK(end <= column.size() && begin <= end)
|
||||
<< ", begin:" << begin << ", end:" << end << ", column.size():" << column.size();
|
||||
for (size_t i = begin; i <= end; ++i) {
|
||||
VectorBufferReader buffer_reader(
|
||||
(assert_cast<const ColumnString&>(column)).get_data_at(i));
|
||||
deserialize_and_merge(place, buffer_reader, arena);
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column,
|
||||
Arena* arena) const override {
|
||||
if (column.empty()) {
|
||||
return;
|
||||
}
|
||||
deserialize_and_merge_from_column_range(place, column, 0, column.size() - 1, arena);
|
||||
}
|
||||
|
||||
void deserialize_and_merge(AggregateDataPtr __restrict place, BufferReadable& buf,
|
||||
Arena* arena) const override {
|
||||
char deserialized_data[size_of_data()];
|
||||
AggregateDataPtr deserialized_place = (AggregateDataPtr)deserialized_data;
|
||||
|
||||
auto derived = static_cast<const Derived*>(this);
|
||||
derived->create(deserialized_place);
|
||||
derived->deserialize(deserialized_place, buf, arena);
|
||||
derived->merge(place, deserialized_place, arena);
|
||||
derived->destroy(deserialized_place);
|
||||
}
|
||||
};
|
||||
|
||||
/// Implements several methods for manipulation with data. T - type of structure with data for aggregation.
|
||||
@ -426,16 +462,6 @@ public:
|
||||
derived->deserialize(deserialized_place, buf, arena);
|
||||
derived->merge(place, deserialized_place, arena);
|
||||
}
|
||||
|
||||
void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column,
|
||||
Arena* arena) const override {
|
||||
size_t num_rows = column.size();
|
||||
for (size_t i = 0; i != num_rows; ++i) {
|
||||
VectorBufferReader buffer_reader(
|
||||
(assert_cast<const ColumnString&>(column)).get_data_at(i));
|
||||
deserialize_and_merge(place, buffer_reader, arena);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using AggregateFunctionPtr = std::shared_ptr<IAggregateFunction>;
|
||||
|
||||
@ -221,8 +221,8 @@ public:
|
||||
}
|
||||
|
||||
void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
|
||||
MutableColumnPtr& dst) const override {
|
||||
auto& col = assert_cast<ColumnFixedLengthObject&>(*dst);
|
||||
IColumn& to) const override {
|
||||
auto& col = assert_cast<ColumnFixedLengthObject&>(to);
|
||||
col.set_item_size(sizeof(Data));
|
||||
col.resize(1);
|
||||
*reinterpret_cast<Data*>(col.get_data().data()) = this->data(place);
|
||||
|
||||
@ -104,7 +104,7 @@ public:
|
||||
col.resize(num_rows);
|
||||
auto* data = col.get_data().data();
|
||||
for (size_t i = 0; i != num_rows; ++i) {
|
||||
data[i] = this->data(places[i] + offset).count;
|
||||
data[i] = AggregateFunctionCount::data(places[i] + offset).count;
|
||||
}
|
||||
}
|
||||
|
||||
@ -120,15 +120,16 @@ public:
|
||||
auto data = assert_cast<const ColumnUInt64&>(column).get_data().data();
|
||||
const size_t num_rows = column.size();
|
||||
for (size_t i = 0; i != num_rows; ++i) {
|
||||
this->data(place).count += data[i];
|
||||
AggregateFunctionCount::data(place).count += data[i];
|
||||
}
|
||||
}
|
||||
|
||||
void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
|
||||
MutableColumnPtr& dst) const override {
|
||||
auto& col = assert_cast<ColumnUInt64&>(*dst);
|
||||
IColumn& to) const override {
|
||||
auto& col = assert_cast<ColumnUInt64&>(to);
|
||||
col.resize(1);
|
||||
reinterpret_cast<Data*>(col.get_data().data())->count = this->data(place).count;
|
||||
reinterpret_cast<Data*>(col.get_data().data())->count =
|
||||
AggregateFunctionCount::data(place).count;
|
||||
}
|
||||
|
||||
MutableColumnPtr create_serialize_column() const override {
|
||||
@ -199,7 +200,7 @@ public:
|
||||
col.resize(num_rows);
|
||||
auto* data = col.get_data().data();
|
||||
for (size_t i = 0; i != num_rows; ++i) {
|
||||
data[i] = this->data(places[i] + offset).count;
|
||||
data[i] = AggregateFunctionCountNotNullUnary::data(places[i] + offset).count;
|
||||
}
|
||||
}
|
||||
|
||||
@ -219,15 +220,16 @@ public:
|
||||
auto data = assert_cast<const ColumnUInt64&>(column).get_data().data();
|
||||
const size_t num_rows = column.size();
|
||||
for (size_t i = 0; i != num_rows; ++i) {
|
||||
this->data(place).count += data[i];
|
||||
AggregateFunctionCountNotNullUnary::data(place).count += data[i];
|
||||
}
|
||||
}
|
||||
|
||||
void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
|
||||
MutableColumnPtr& dst) const override {
|
||||
auto& col = assert_cast<ColumnUInt64&>(*dst);
|
||||
IColumn& to) const override {
|
||||
auto& col = assert_cast<ColumnUInt64&>(to);
|
||||
col.resize(1);
|
||||
reinterpret_cast<Data*>(col.get_data().data())->count = this->data(place).count;
|
||||
reinterpret_cast<Data*>(col.get_data().data())->count =
|
||||
AggregateFunctionCountNotNullUnary::data(place).count;
|
||||
}
|
||||
|
||||
MutableColumnPtr create_serialize_column() const override {
|
||||
|
||||
@ -219,7 +219,6 @@ public:
|
||||
this->data(place).deserialize(buf, arena);
|
||||
}
|
||||
|
||||
// void insert_result_into(AggregateDataPtr place, IColumn & to, Arena * arena) const override
|
||||
void insert_result_into(ConstAggregateDataPtr targetplace, IColumn& to) const override {
|
||||
auto place = const_cast<AggregateDataPtr>(targetplace);
|
||||
auto arguments = this->data(place).get_arguments(this->argument_types);
|
||||
|
||||
@ -609,13 +609,13 @@ public:
|
||||
}
|
||||
|
||||
void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
|
||||
MutableColumnPtr& dst) const override {
|
||||
IColumn& to) const override {
|
||||
if constexpr (Data::IsFixedLength) {
|
||||
auto& col = assert_cast<ColumnFixedLengthObject&>(*dst);
|
||||
auto& col = assert_cast<ColumnFixedLengthObject&>(to);
|
||||
col.resize(1);
|
||||
*reinterpret_cast<Data*>(col.get_data().data()) = this->data(place);
|
||||
} else {
|
||||
Base::serialize_without_key_to_column(place, dst);
|
||||
Base::serialize_without_key_to_column(place, to);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -155,16 +155,6 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column,
|
||||
Arena* arena) const override {
|
||||
size_t num_rows = column.size();
|
||||
for (size_t i = 0; i != num_rows; ++i) {
|
||||
VectorBufferReader buffer_reader(
|
||||
(assert_cast<const ColumnString&>(column)).get_data_at(i));
|
||||
deserialize_and_merge(place, buffer_reader, arena);
|
||||
}
|
||||
}
|
||||
|
||||
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
|
||||
if constexpr (result_is_nullable) {
|
||||
ColumnNullable& to_concrete = assert_cast<ColumnNullable&>(to);
|
||||
|
||||
@ -0,0 +1,50 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "vec/aggregate_functions/aggregate_function_state_union.h"
|
||||
|
||||
namespace doris::vectorized {
|
||||
const static std::string AGG_MERGE_SUFFIX = "_merge";
|
||||
|
||||
class AggregateStateMerge : public AggregateStateUnion {
|
||||
public:
|
||||
AggregateStateMerge(AggregateFunctionPtr function, const DataTypes& argument_types,
|
||||
const DataTypePtr& return_type)
|
||||
: AggregateStateUnion(function, argument_types, return_type) {}
|
||||
|
||||
static AggregateFunctionPtr create(AggregateFunctionPtr function,
|
||||
const DataTypes& argument_types,
|
||||
const DataTypePtr& return_type) {
|
||||
CHECK(argument_types.size() == 1);
|
||||
if (function == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return std::make_shared<AggregateStateMerge>(function, argument_types, return_type);
|
||||
}
|
||||
|
||||
String get_name() const override { return _function->get_name() + AGG_MERGE_SUFFIX; }
|
||||
|
||||
DataTypePtr get_return_type() const override { return _function->get_return_type(); }
|
||||
|
||||
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
|
||||
_function->insert_result_into(place, to);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace doris::vectorized
|
||||
@ -0,0 +1,99 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "vec/aggregate_functions/aggregate_function.h"
|
||||
#include "vec/data_types/data_type_agg_state.h"
|
||||
|
||||
namespace doris::vectorized {
|
||||
const static std::string AGG_UNION_SUFFIX = "_union";
|
||||
|
||||
class AggregateStateUnion : public IAggregateFunctionHelper<AggregateStateUnion> {
|
||||
public:
|
||||
AggregateStateUnion(AggregateFunctionPtr function, const DataTypes& argument_types,
|
||||
const DataTypePtr& return_type)
|
||||
: IAggregateFunctionHelper(argument_types),
|
||||
_function(function),
|
||||
_return_type(return_type) {}
|
||||
~AggregateStateUnion() override = default;
|
||||
|
||||
static AggregateFunctionPtr create(AggregateFunctionPtr function,
|
||||
const DataTypes& argument_types,
|
||||
const DataTypePtr& return_type) {
|
||||
CHECK(argument_types.size() == 1);
|
||||
if (function == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return std::make_shared<AggregateStateUnion>(function, argument_types, return_type);
|
||||
}
|
||||
|
||||
void create(AggregateDataPtr __restrict place) const override { _function->create(place); }
|
||||
|
||||
String get_name() const override { return _function->get_name() + AGG_UNION_SUFFIX; }
|
||||
|
||||
DataTypePtr get_return_type() const override { return _return_type; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
|
||||
Arena* arena) const override {
|
||||
VectorBufferReader buffer_reader(
|
||||
(assert_cast<const ColumnString&>(*columns[0])).get_data_at(row_num));
|
||||
deserialize_and_merge(place, buffer_reader, arena);
|
||||
}
|
||||
|
||||
void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
|
||||
Arena* arena) const override {
|
||||
_function->deserialize_and_merge_from_column_range(place, *columns[0], 0, batch_size - 1,
|
||||
arena);
|
||||
}
|
||||
|
||||
void reset(AggregateDataPtr place) const override { _function->reset(place); }
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
|
||||
Arena* arena) const override {
|
||||
_function->merge(place, rhs, arena);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
|
||||
_function->serialize(place, buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
|
||||
Arena* arena) const override {
|
||||
_function->deserialize(place, buf, arena);
|
||||
}
|
||||
|
||||
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
|
||||
_function->serialize_without_key_to_column(place, to);
|
||||
}
|
||||
|
||||
void destroy(AggregateDataPtr __restrict place) const noexcept override {
|
||||
_function->destroy(place);
|
||||
}
|
||||
|
||||
bool has_trivial_destructor() const override { return _function->has_trivial_destructor(); }
|
||||
|
||||
size_t size_of_data() const override { return _function->size_of_data(); }
|
||||
|
||||
size_t align_of_data() const override { return _function->align_of_data(); }
|
||||
|
||||
protected:
|
||||
AggregateFunctionPtr _function;
|
||||
DataTypePtr _return_type;
|
||||
};
|
||||
|
||||
} // namespace doris::vectorized
|
||||
@ -160,8 +160,8 @@ public:
|
||||
}
|
||||
|
||||
void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
|
||||
MutableColumnPtr& dst) const override {
|
||||
auto& col = assert_cast<ColVecResult&>(*dst);
|
||||
IColumn& to) const override {
|
||||
auto& col = assert_cast<ColVecResult&>(to);
|
||||
col.resize(1);
|
||||
reinterpret_cast<Data*>(col.get_data().data())->sum = this->data(place).sum;
|
||||
}
|
||||
|
||||
@ -52,9 +52,9 @@ public:
|
||||
return TPrimitiveType::AGG_STATE;
|
||||
}
|
||||
|
||||
const DataTypes& get_sub_types() { return sub_types; }
|
||||
const DataTypes& get_sub_types() const { return sub_types; }
|
||||
|
||||
void add_sub_type(DataTypePtr type) { sub_types.push_back(type); }
|
||||
void add_sub_type(DataTypePtr type) const { sub_types.push_back(type); }
|
||||
|
||||
void to_pb_column_meta(PColumnMeta* col_meta) const override {
|
||||
IDataType::to_pb_column_meta(col_meta);
|
||||
@ -64,7 +64,7 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
DataTypes sub_types;
|
||||
mutable DataTypes sub_types;
|
||||
};
|
||||
|
||||
} // namespace doris::vectorized
|
||||
|
||||
@ -716,7 +716,7 @@ Status AggregationNode::_serialize_without_key(RuntimeState* state, Block* block
|
||||
|
||||
for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
|
||||
_aggregate_evaluators[i]->function()->serialize_without_key_to_column(
|
||||
_agg_data->without_key + _offsets_of_aggregate_states[i], value_columns[i]);
|
||||
_agg_data->without_key + _offsets_of_aggregate_states[i], *value_columns[i]);
|
||||
}
|
||||
} else {
|
||||
std::vector<VectorBufferWriter> value_buffer_writers;
|
||||
|
||||
@ -26,6 +26,7 @@
|
||||
#include "vec/exprs/table_function/vexplode_json_array.h"
|
||||
#include "vec/exprs/table_function/vexplode_numbers.h"
|
||||
#include "vec/exprs/table_function/vexplode_split.h"
|
||||
#include "vec/utils/util.hpp"
|
||||
|
||||
namespace doris::vectorized {
|
||||
|
||||
@ -61,17 +62,6 @@ const std::unordered_map<std::string, std::function<std::unique_ptr<TableFunctio
|
||||
|
||||
Status TableFunctionFactory::get_fn(const std::string& fn_name_raw, ObjectPool* pool,
|
||||
TableFunction** fn) {
|
||||
auto match_suffix = [](const std::string& name, const std::string& suffix) -> bool {
|
||||
if (name.length() < suffix.length()) {
|
||||
return false;
|
||||
}
|
||||
return name.substr(name.length() - suffix.length()) == suffix;
|
||||
};
|
||||
|
||||
auto remove_suffix = [](const std::string& name, const std::string& suffix) -> std::string {
|
||||
return name.substr(0, name.length() - suffix.length());
|
||||
};
|
||||
|
||||
bool is_outer = match_suffix(fn_name_raw, COMBINATOR_SUFFIX_OUTER);
|
||||
std::string fn_name_real =
|
||||
is_outer ? remove_suffix(fn_name_raw, COMBINATOR_SUFFIX_OUTER) : fn_name_raw;
|
||||
|
||||
@ -33,12 +33,16 @@
|
||||
#include "vec/aggregate_functions/aggregate_function_rpc.h"
|
||||
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
|
||||
#include "vec/aggregate_functions/aggregate_function_sort.h"
|
||||
#include "vec/aggregate_functions/aggregate_function_state_merge.h"
|
||||
#include "vec/aggregate_functions/aggregate_function_state_union.h"
|
||||
#include "vec/core/block.h"
|
||||
#include "vec/core/column_with_type_and_name.h"
|
||||
#include "vec/core/materialize_block.h"
|
||||
#include "vec/data_types/data_type_agg_state.h"
|
||||
#include "vec/data_types/data_type_factory.hpp"
|
||||
#include "vec/exprs/vexpr.h"
|
||||
#include "vec/exprs/vexpr_context.h"
|
||||
#include "vec/utils/util.hpp"
|
||||
|
||||
namespace doris {
|
||||
class RowDescriptor;
|
||||
@ -51,6 +55,17 @@ class IColumn;
|
||||
|
||||
namespace doris::vectorized {
|
||||
|
||||
template <class FunctionType>
|
||||
AggregateFunctionPtr get_agg_state_function(const std::string& name,
|
||||
const DataTypes& argument_types,
|
||||
DataTypePtr return_type) {
|
||||
return FunctionType::create(
|
||||
AggregateFunctionSimpleFactory::instance().get(
|
||||
name, ((DataTypeAggState*)argument_types[0].get())->get_sub_types(),
|
||||
return_type->is_nullable()),
|
||||
argument_types, return_type);
|
||||
}
|
||||
|
||||
AggFnEvaluator::AggFnEvaluator(const TExprNode& desc)
|
||||
: _fn(desc.fn),
|
||||
_is_merge(desc.agg_expr.is_merge_agg),
|
||||
@ -143,6 +158,19 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc,
|
||||
}
|
||||
} else if (_fn.binary_type == TFunctionBinaryType::RPC) {
|
||||
_function = AggregateRpcUdaf::create(_fn, argument_types, _data_type);
|
||||
} else if (_fn.binary_type == TFunctionBinaryType::AGG_STATE) {
|
||||
if (match_suffix(_fn.name.function_name, AGG_UNION_SUFFIX)) {
|
||||
_function = get_agg_state_function<AggregateStateUnion>(
|
||||
remove_suffix(_fn.name.function_name, AGG_UNION_SUFFIX), argument_types,
|
||||
_data_type);
|
||||
} else if (match_suffix(_fn.name.function_name, AGG_MERGE_SUFFIX)) {
|
||||
_function = get_agg_state_function<AggregateStateMerge>(
|
||||
remove_suffix(_fn.name.function_name, AGG_MERGE_SUFFIX), argument_types,
|
||||
_data_type);
|
||||
} else {
|
||||
return Status::InternalError(
|
||||
"Aggregate Function {} is not endwith '_merge' or '_union'", _fn.signature);
|
||||
}
|
||||
} else {
|
||||
_function = AggregateFunctionSimpleFactory::instance().get(
|
||||
_fn.name.function_name, argument_types, _data_type->is_nullable());
|
||||
|
||||
@ -32,15 +32,19 @@
|
||||
#include "common/status.h"
|
||||
#include "runtime/runtime_state.h"
|
||||
#include "udf/udf.h"
|
||||
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
|
||||
#include "vec/columns/column.h"
|
||||
#include "vec/core/block.h"
|
||||
#include "vec/core/column_with_type_and_name.h"
|
||||
#include "vec/core/columns_with_type_and_name.h"
|
||||
#include "vec/data_types/data_type.h"
|
||||
#include "vec/data_types/data_type_agg_state.h"
|
||||
#include "vec/exprs/vexpr_context.h"
|
||||
#include "vec/functions/function_agg_state.h"
|
||||
#include "vec/functions/function_java_udf.h"
|
||||
#include "vec/functions/function_rpc.h"
|
||||
#include "vec/functions/simple_function_factory.h"
|
||||
#include "vec/utils/util.hpp"
|
||||
|
||||
namespace doris {
|
||||
class RowDescriptor;
|
||||
@ -50,6 +54,8 @@ class TExprNode;
|
||||
|
||||
namespace doris::vectorized {
|
||||
|
||||
const std::string AGG_STATE_SUFFIX = "_state";
|
||||
|
||||
VectorizedFnCall::VectorizedFnCall(const TExprNode& node) : VExpr(node) {}
|
||||
|
||||
Status VectorizedFnCall::prepare(RuntimeState* state, const RowDescriptor& desc,
|
||||
@ -62,6 +68,7 @@ Status VectorizedFnCall::prepare(RuntimeState* state, const RowDescriptor& desc,
|
||||
argument_template.emplace_back(nullptr, child->data_type(), child->expr_name());
|
||||
child_expr_name.emplace_back(child->expr_name());
|
||||
}
|
||||
|
||||
if (_fn.binary_type == TFunctionBinaryType::RPC) {
|
||||
_function = FunctionRPC::create(_fn, argument_template, _data_type);
|
||||
} else if (_fn.binary_type == TFunctionBinaryType::JAVA_UDF) {
|
||||
@ -72,6 +79,29 @@ Status VectorizedFnCall::prepare(RuntimeState* state, const RowDescriptor& desc,
|
||||
"Java UDF is not enabled, you can change be config enable_java_support to true "
|
||||
"and restart be.");
|
||||
}
|
||||
} else if (_fn.binary_type == TFunctionBinaryType::AGG_STATE) {
|
||||
DataTypes argument_types;
|
||||
for (auto column : argument_template) {
|
||||
argument_types.emplace_back(column.type);
|
||||
}
|
||||
|
||||
if (match_suffix(_fn.name.function_name, AGG_STATE_SUFFIX)) {
|
||||
if (_data_type->is_nullable()) {
|
||||
return Status::InternalError("State function's return type must be not nullable");
|
||||
}
|
||||
if (_data_type->get_type_as_primitive_type() != PrimitiveType::TYPE_AGG_STATE) {
|
||||
return Status::InternalError(
|
||||
"State function's return type must be agg_state but get {}",
|
||||
_data_type->get_family_name());
|
||||
}
|
||||
_function = FunctionAggState::create(
|
||||
argument_types, _data_type,
|
||||
AggregateFunctionSimpleFactory::instance().get(
|
||||
remove_suffix(_fn.name.function_name, AGG_STATE_SUFFIX), argument_types,
|
||||
_data_type->is_nullable()));
|
||||
} else {
|
||||
return Status::InternalError("Function {} is not endwith '_state'", _fn.signature);
|
||||
}
|
||||
} else {
|
||||
// get the function. won't prepare function.
|
||||
_function = SimpleFunctionFactory::instance().get_function(
|
||||
|
||||
85
be/src/vec/functions/function_agg_state.h
Normal file
85
be/src/vec/functions/function_agg_state.h
Normal file
@ -0,0 +1,85 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include "common/status.h"
|
||||
#include "vec/aggregate_functions/aggregate_function.h"
|
||||
#include "vec/common/arena.h"
|
||||
#include "vec/core/block.h"
|
||||
#include "vec/core/types.h"
|
||||
#include "vec/data_types/data_type.h"
|
||||
#include "vec/functions/function.h"
|
||||
|
||||
namespace doris::vectorized {
|
||||
|
||||
class FunctionAggState : public IFunction {
|
||||
public:
|
||||
FunctionAggState(const DataTypes& argument_types, const DataTypePtr& return_type,
|
||||
AggregateFunctionPtr agg_function)
|
||||
: _argument_types(argument_types),
|
||||
_return_type(return_type),
|
||||
_agg_function(agg_function) {}
|
||||
|
||||
static FunctionBasePtr create(const DataTypes& argument_types, const DataTypePtr& return_type,
|
||||
AggregateFunctionPtr agg_function) {
|
||||
if (agg_function == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return std::make_shared<DefaultFunction>(
|
||||
std::make_shared<FunctionAggState>(argument_types, return_type, agg_function),
|
||||
argument_types, return_type);
|
||||
}
|
||||
|
||||
size_t get_number_of_arguments() const override { return _argument_types.size(); }
|
||||
|
||||
bool use_default_implementation_for_constants() const override { return true; }
|
||||
bool use_default_implementation_for_nulls() const override { return false; }
|
||||
|
||||
String get_name() const override { return fmt::format("{}_state", _agg_function->get_name()); }
|
||||
|
||||
DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
|
||||
return _return_type;
|
||||
}
|
||||
|
||||
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
|
||||
size_t result, size_t input_rows_count) override {
|
||||
auto col = _return_type->create_column();
|
||||
std::vector<const IColumn*> agg_columns;
|
||||
|
||||
for (size_t index : arguments) {
|
||||
agg_columns.push_back(block.get_by_position(index).column);
|
||||
}
|
||||
|
||||
VectorBufferWriter writter(assert_cast<ColumnString&>(*col));
|
||||
_agg_function->streaming_agg_serialize(agg_columns.data(), writter, input_rows_count,
|
||||
&arena);
|
||||
|
||||
block.replace_by_position(result, std::move(col));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
DataTypes _argument_types;
|
||||
DataTypePtr _return_type;
|
||||
AggregateFunctionPtr _agg_function;
|
||||
Arena arena;
|
||||
};
|
||||
|
||||
} // namespace doris::vectorized
|
||||
@ -465,7 +465,7 @@ void BlockReader::_update_agg_value(MutableColumns& columns, int begin, int end,
|
||||
|
||||
if (begin <= end) {
|
||||
function->add_batch_range(begin, end, place, const_cast<const IColumn**>(&column_ptr),
|
||||
nullptr, _stored_has_null_tag[idx]);
|
||||
&_arena, _stored_has_null_tag[idx]);
|
||||
}
|
||||
|
||||
if (is_close) {
|
||||
|
||||
@ -102,6 +102,19 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
inline bool match_suffix(const std::string& name, const std::string& suffix) {
|
||||
if (name.length() < suffix.length()) {
|
||||
return false;
|
||||
}
|
||||
return name.substr(name.length() - suffix.length()) == suffix;
|
||||
}
|
||||
|
||||
inline std::string remove_suffix(const std::string& name, const std::string& suffix) {
|
||||
CHECK(match_suffix(name, suffix))
|
||||
<< ", suffix not match, name=" << name << ", suffix=" << suffix;
|
||||
return name.substr(0, name.length() - suffix.length());
|
||||
};
|
||||
|
||||
} // namespace doris::vectorized
|
||||
|
||||
namespace apache::thrift {
|
||||
|
||||
@ -66,8 +66,8 @@ public class BuiltinAggregateFunction extends Function {
|
||||
}
|
||||
|
||||
@Override
|
||||
public TFunction toThrift(Type realReturnType, Type[] realArgTypes) {
|
||||
TFunction fn = super.toThrift(realReturnType, realArgTypes);
|
||||
public TFunction toThrift(Type realReturnType, Type[] realArgTypes, Boolean[] realArgTypeNullables) {
|
||||
TFunction fn = super.toThrift(realReturnType, realArgTypes, realArgTypeNullables);
|
||||
// TODO: for now, just put the op_ enum as the id.
|
||||
if (op == BuiltinAggregateFunction.Operator.FIRST_VALUE_REWRITE) {
|
||||
fn.setId(0);
|
||||
|
||||
@ -17,7 +17,6 @@
|
||||
|
||||
package org.apache.doris.analysis;
|
||||
|
||||
import org.apache.doris.catalog.AggregateType;
|
||||
import org.apache.doris.catalog.Column;
|
||||
import org.apache.doris.catalog.DatabaseIf;
|
||||
import org.apache.doris.catalog.Env;
|
||||
@ -200,11 +199,7 @@ public class DescribeStmt extends ShowStmt {
|
||||
// Extra string (aggregation and bloom filter)
|
||||
List<String> extras = Lists.newArrayList();
|
||||
if (column.getAggregationType() != null) {
|
||||
if (column.getAggregationType() == AggregateType.GENERIC_AGGREGATION) {
|
||||
extras.add(column.getGenericAggregationString());
|
||||
} else {
|
||||
extras.add(column.getAggregationType().name());
|
||||
}
|
||||
extras.add(column.getAggregationString());
|
||||
}
|
||||
if (bfColumns != null && bfColumns.contains(column.getName())) {
|
||||
extras.add("BLOOM_FILTER");
|
||||
|
||||
@ -21,12 +21,15 @@
|
||||
package org.apache.doris.analysis;
|
||||
|
||||
import org.apache.doris.analysis.ArithmeticExpr.Operator;
|
||||
import org.apache.doris.catalog.AggregateFunction;
|
||||
import org.apache.doris.catalog.ArrayType;
|
||||
import org.apache.doris.catalog.Env;
|
||||
import org.apache.doris.catalog.Function;
|
||||
import org.apache.doris.catalog.Function.NullableMode;
|
||||
import org.apache.doris.catalog.FunctionSet;
|
||||
import org.apache.doris.catalog.MaterializedIndexMeta;
|
||||
import org.apache.doris.catalog.PrimitiveType;
|
||||
import org.apache.doris.catalog.ScalarFunction;
|
||||
import org.apache.doris.catalog.ScalarType;
|
||||
import org.apache.doris.catalog.Type;
|
||||
import org.apache.doris.common.AnalysisException;
|
||||
@ -40,6 +43,7 @@ import org.apache.doris.statistics.ExprStats;
|
||||
import org.apache.doris.thrift.TExpr;
|
||||
import org.apache.doris.thrift.TExprNode;
|
||||
import org.apache.doris.thrift.TExprOpcode;
|
||||
import org.apache.doris.thrift.TFunctionBinaryType;
|
||||
|
||||
import com.google.common.base.Joiner;
|
||||
import com.google.common.base.MoreObjects;
|
||||
@ -75,6 +79,10 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
|
||||
// supports negation.
|
||||
private static final String NEGATE_FN = "negate";
|
||||
|
||||
public static final String AGG_STATE_SUFFIX = "_state";
|
||||
public static final String AGG_UNION_SUFFIX = "_union";
|
||||
public static final String AGG_MERGE_SUFFIX = "_merge";
|
||||
|
||||
protected boolean disableTableName = false;
|
||||
|
||||
// to be used where we can't come up with a better estimate
|
||||
@ -436,6 +444,10 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
|
||||
setSelectivity();
|
||||
}
|
||||
analysisDone();
|
||||
if (type.isAggStateType() && !(this instanceof SlotRef) && ((ScalarType) type).getSubTypes() == null) {
|
||||
type = new ScalarType(Arrays.asList(collectChildReturnTypes()),
|
||||
Arrays.asList(collectChildReturnNullables()));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@ -482,6 +494,14 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
|
||||
return childTypes;
|
||||
}
|
||||
|
||||
protected Boolean[] collectChildReturnNullables() {
|
||||
Boolean[] childNullables = new Boolean[children.size()];
|
||||
for (int i = 0; i < children.size(); ++i) {
|
||||
childNullables[i] = children.get(i).isNullable();
|
||||
}
|
||||
return childNullables;
|
||||
}
|
||||
|
||||
public List<Expr> getChildrenWithoutCast() {
|
||||
List<Expr> result = new ArrayList<>();
|
||||
for (int i = 0; i < children.size(); ++i) {
|
||||
@ -984,10 +1004,14 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
|
||||
// Append a flattened version of this expr, including all children, to 'container'.
|
||||
protected void treeToThriftHelper(TExpr container) {
|
||||
TExprNode msg = new TExprNode();
|
||||
if (type.isAggStateType() && ((ScalarType) type).getSubTypes() == null) {
|
||||
type = new ScalarType(Arrays.asList(collectChildReturnTypes()),
|
||||
Arrays.asList(collectChildReturnNullables()));
|
||||
}
|
||||
msg.type = type.toThrift();
|
||||
msg.num_children = children.size();
|
||||
if (fn != null) {
|
||||
msg.setFn(fn.toThrift(type, collectChildReturnTypes()));
|
||||
msg.setFn(fn.toThrift(type, collectChildReturnTypes(), collectChildReturnNullables()));
|
||||
if (fn.hasVarArgs()) {
|
||||
msg.setVarargStartIdx(fn.getNumArgs() - 1);
|
||||
}
|
||||
@ -1382,7 +1406,8 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
|
||||
}
|
||||
|
||||
public Expr checkTypeCompatibility(Type targetType) throws AnalysisException {
|
||||
if (!targetType.isComplexType() && targetType.getPrimitiveType() == type.getPrimitiveType()) {
|
||||
if (!targetType.isComplexType() && !targetType.isAggStateType()
|
||||
&& targetType.getPrimitiveType() == type.getPrimitiveType()) {
|
||||
if (targetType.isDecimalV2() && type.isDecimalV2()) {
|
||||
return this;
|
||||
} else if (!PrimitiveType.typeWithPrecision.contains(type.getPrimitiveType())) {
|
||||
@ -1392,6 +1417,10 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
|
||||
return this;
|
||||
}
|
||||
}
|
||||
if (type.isAggStateType() != targetType.isAggStateType()) {
|
||||
throw new AnalysisException("AggState can't cast from other type.");
|
||||
}
|
||||
|
||||
// bitmap must match exactly
|
||||
if (targetType.getPrimitiveType() == PrimitiveType.BITMAP) {
|
||||
throw new AnalysisException("bitmap column require the function return type is BITMAP");
|
||||
@ -1452,8 +1481,32 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
|
||||
return this;
|
||||
}
|
||||
|
||||
if (this.type.equals(targetType)) {
|
||||
return this;
|
||||
if (this.type.isAggStateType()) {
|
||||
List<Type> subTypes = ((ScalarType) targetType).getSubTypes();
|
||||
|
||||
if (this instanceof FunctionCallExpr) {
|
||||
if (subTypes.size() != getChildren().size()) {
|
||||
throw new AnalysisException("AggState's subTypes size not euqal to children number");
|
||||
}
|
||||
for (int i = 0; i < subTypes.size(); i++) {
|
||||
setChild(i, getChild(i).castTo(subTypes.get(i)));
|
||||
}
|
||||
type = targetType;
|
||||
} else {
|
||||
List<Type> selfSubTypes = ((ScalarType) type).getSubTypes();
|
||||
if (subTypes.size() != selfSubTypes.size()) {
|
||||
throw new AnalysisException("AggState's subTypes size did not match");
|
||||
}
|
||||
for (int i = 0; i < subTypes.size(); i++) {
|
||||
if (subTypes.get(i) != selfSubTypes.get(i)) {
|
||||
throw new AnalysisException("AggState's subType did not match");
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (this.type.equals(targetType)) {
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
if (targetType.getPrimitiveType() == PrimitiveType.DECIMALV2
|
||||
@ -1793,14 +1846,72 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
|
||||
throws AnalysisException {
|
||||
FunctionName fnName = new FunctionName(name);
|
||||
Function searchDesc = new Function(fnName, Arrays.asList(getActualArgTypes(argTypes)), Type.INVALID, false,
|
||||
VectorizedUtil.isVectorized());
|
||||
true);
|
||||
Function f = Env.getCurrentEnv().getFunction(searchDesc, mode);
|
||||
if (f != null && fnName.getFunction().equalsIgnoreCase("rand")) {
|
||||
if (this.children.size() == 1
|
||||
&& !(this.children.get(0) instanceof LiteralExpr)) {
|
||||
if (this.children.size() == 1 && !(this.children.get(0) instanceof LiteralExpr)) {
|
||||
throw new AnalysisException("The param of rand function must be literal");
|
||||
}
|
||||
}
|
||||
if (f != null) {
|
||||
return f;
|
||||
}
|
||||
|
||||
boolean isUnion = name.toLowerCase().endsWith(AGG_UNION_SUFFIX);
|
||||
boolean isMerge = name.toLowerCase().endsWith(AGG_MERGE_SUFFIX);
|
||||
boolean isState = name.toLowerCase().endsWith(AGG_STATE_SUFFIX);
|
||||
if (isUnion || isMerge || isState) {
|
||||
if (isUnion) {
|
||||
name = name.substring(0, name.length() - AGG_UNION_SUFFIX.length());
|
||||
}
|
||||
if (isMerge) {
|
||||
name = name.substring(0, name.length() - AGG_MERGE_SUFFIX.length());
|
||||
}
|
||||
if (isState) {
|
||||
name = name.substring(0, name.length() - AGG_STATE_SUFFIX.length());
|
||||
}
|
||||
|
||||
List<Type> argList = Arrays.asList(getActualArgTypes(argTypes));
|
||||
List<Type> nestedArgList;
|
||||
if (isState) {
|
||||
nestedArgList = argList;
|
||||
} else {
|
||||
if (argList.size() != 1 || !argList.get(0).isAggStateType()) {
|
||||
throw new AnalysisException("merge/union function must input one agg_state");
|
||||
}
|
||||
ScalarType aggState = (ScalarType) argList.get(0);
|
||||
if (aggState.getSubTypes() == null) {
|
||||
throw new AnalysisException("agg_state's subTypes is null");
|
||||
}
|
||||
nestedArgList = aggState.getSubTypes();
|
||||
}
|
||||
|
||||
searchDesc = new Function(new FunctionName(name), nestedArgList, Type.INVALID, false, true);
|
||||
|
||||
f = Env.getCurrentEnv().getFunction(searchDesc, mode);
|
||||
if (f == null || !(f instanceof AggregateFunction)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (isState) {
|
||||
f = new ScalarFunction(new FunctionName(name + AGG_STATE_SUFFIX), Arrays.asList(f.getArgs()),
|
||||
Type.AGG_STATE, f.hasVarArgs(), f.isUserVisible());
|
||||
f.setNullableMode(NullableMode.ALWAYS_NOT_NULLABLE);
|
||||
} else {
|
||||
f = ((AggregateFunction) f).clone();
|
||||
f.setArgs(argList);
|
||||
if (isUnion) {
|
||||
f.setName(new FunctionName(name + AGG_UNION_SUFFIX));
|
||||
f.setReturnType((ScalarType) argList.get(0));
|
||||
f.setNullableMode(NullableMode.ALWAYS_NOT_NULLABLE);
|
||||
}
|
||||
if (isMerge) {
|
||||
f.setName(new FunctionName(name + AGG_MERGE_SUFFIX));
|
||||
}
|
||||
}
|
||||
f.setBinaryType(TFunctionBinaryType.AGG_STATE);
|
||||
}
|
||||
|
||||
return f;
|
||||
}
|
||||
|
||||
|
||||
@ -589,8 +589,8 @@ public class AggregateFunction extends Function {
|
||||
}
|
||||
|
||||
@Override
|
||||
public TFunction toThrift(Type realReturnType, Type[] realArgTypes) {
|
||||
TFunction fn = super.toThrift(realReturnType, realArgTypes);
|
||||
public TFunction toThrift(Type realReturnType, Type[] realArgTypes, Boolean[] realArgTypeNullables) {
|
||||
TFunction fn = super.toThrift(realReturnType, realArgTypes, realArgTypeNullables);
|
||||
TAggregateFunction aggFn = new TAggregateFunction();
|
||||
aggFn.setIsAnalyticOnlyFn(isAnalyticFn && !isAggregateFn);
|
||||
aggFn.setUpdateFnSymbol(updateFnSymbol);
|
||||
|
||||
@ -397,6 +397,14 @@ public class Column implements Writable, GsonPostProcessable {
|
||||
return this.aggregationType;
|
||||
}
|
||||
|
||||
public String getAggregationString() {
|
||||
if (getAggregationType() == AggregateType.GENERIC_AGGREGATION) {
|
||||
return getGenericAggregationString();
|
||||
} else {
|
||||
return getAggregationType().name();
|
||||
}
|
||||
}
|
||||
|
||||
public boolean isAggregated() {
|
||||
return aggregationType != null && aggregationType != AggregateType.NONE;
|
||||
}
|
||||
|
||||
@ -17,6 +17,8 @@
|
||||
|
||||
package org.apache.doris.catalog;
|
||||
|
||||
import org.apache.doris.analysis.Expr;
|
||||
import org.apache.doris.analysis.FunctionCallExpr;
|
||||
import org.apache.doris.analysis.FunctionName;
|
||||
import org.apache.doris.common.AnalysisException;
|
||||
import org.apache.doris.common.UserException;
|
||||
@ -42,6 +44,7 @@ import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Base class for all functions.
|
||||
@ -217,6 +220,10 @@ public class Function implements Writable {
|
||||
return argTypes;
|
||||
}
|
||||
|
||||
public void setArgs(List<Type> argTypes) {
|
||||
this.argTypes = argTypes.toArray(new Type[argTypes.size()]);
|
||||
}
|
||||
|
||||
// Returns the number of arguments to this function.
|
||||
public int getNumArgs() {
|
||||
return argTypes.length;
|
||||
@ -506,7 +513,7 @@ public class Function implements Writable {
|
||||
return retType instanceof AnyType;
|
||||
}
|
||||
|
||||
public TFunction toThrift(Type realReturnType, Type[] realArgTypes) {
|
||||
public TFunction toThrift(Type realReturnType, Type[] realArgTypes, Boolean[] realArgTypeNullables) {
|
||||
TFunction fn = new TFunction();
|
||||
fn.setSignature(signatureString());
|
||||
fn.setName(name.toThrift());
|
||||
@ -514,16 +521,24 @@ public class Function implements Writable {
|
||||
if (location != null) {
|
||||
fn.setHdfsLocation(location.getLocation());
|
||||
}
|
||||
// `realArgTypes.length != argTypes.length` is true iff this is an aggregation function.
|
||||
// For aggregation functions, `argTypes` here is already its real type with true precision and scale.
|
||||
// `realArgTypes.length != argTypes.length` is true iff this is an aggregation
|
||||
// function.
|
||||
// For aggregation functions, `argTypes` here is already its real type with true
|
||||
// precision and scale.
|
||||
if (realArgTypes.length != argTypes.length) {
|
||||
fn.setArgTypes(Type.toThrift(Lists.newArrayList(argTypes)));
|
||||
} else {
|
||||
fn.setArgTypes(Type.toThrift(Lists.newArrayList(argTypes), Lists.newArrayList(realArgTypes)));
|
||||
}
|
||||
// For types with different precisions and scales, return type only indicates a type with default
|
||||
|
||||
if (realReturnType.isAggStateType()) {
|
||||
realReturnType = new ScalarType(Arrays.asList(realArgTypes), Arrays.asList(realArgTypeNullables));
|
||||
}
|
||||
|
||||
// For types with different precisions and scales, return type only indicates a
|
||||
// type with default
|
||||
// precision and scale so we need to transform it to the correct type.
|
||||
if (realReturnType.typeContainsPrecision()) {
|
||||
if (realReturnType.typeContainsPrecision() || realReturnType.isAggStateType()) {
|
||||
fn.setRetType(realReturnType.toThrift());
|
||||
} else {
|
||||
fn.setRetType(getReturnType().toThrift());
|
||||
@ -816,4 +831,18 @@ public class Function implements Writable {
|
||||
result = 31 * result + Arrays.hashCode(argTypes);
|
||||
return result;
|
||||
}
|
||||
|
||||
public static FunctionCallExpr convertToStateCombinator(FunctionCallExpr fnCall) {
|
||||
Function aggFunction = fnCall.getFn();
|
||||
List<Type> arguments = Arrays.asList(aggFunction.getArgs());
|
||||
ScalarFunction fn = new org.apache.doris.catalog.ScalarFunction(
|
||||
new FunctionName(aggFunction.getFunctionName().getFunction() + Expr.AGG_STATE_SUFFIX),
|
||||
arguments,
|
||||
new ScalarType(arguments, fnCall.getChildren().stream().map(expr -> {
|
||||
return expr.isNullable();
|
||||
}).collect(Collectors.toList())), aggFunction.hasVarArgs(), aggFunction.isUserVisible());
|
||||
fn.setNullableMode(NullableMode.ALWAYS_NOT_NULLABLE);
|
||||
fn.setBinaryType(TFunctionBinaryType.AGG_STATE);
|
||||
return new FunctionCallExpr(fn, fnCall.getParams());
|
||||
}
|
||||
}
|
||||
|
||||
@ -229,10 +229,10 @@ public class ScalarFunction extends Function {
|
||||
}
|
||||
|
||||
@Override
|
||||
public TFunction toThrift(Type realReturnType, Type[] realArgTypes) {
|
||||
TFunction fn = super.toThrift(realReturnType, realArgTypes);
|
||||
public TFunction toThrift(Type realReturnType, Type[] realArgTypes, Boolean[] realArgTypeNullables) {
|
||||
TFunction fn = super.toThrift(realReturnType, realArgTypes, realArgTypeNullables);
|
||||
fn.setScalarFn(new TScalarFunction());
|
||||
if (getBinaryType() != TFunctionBinaryType.BUILTIN) {
|
||||
if (getBinaryType() == TFunctionBinaryType.JAVA_UDF || getBinaryType() == TFunctionBinaryType.RPC) {
|
||||
fn.getScalarFn().setSymbol(symbolName);
|
||||
} else {
|
||||
fn.getScalarFn().setSymbol("");
|
||||
|
||||
@ -17,7 +17,6 @@
|
||||
|
||||
package org.apache.doris.common.proc;
|
||||
|
||||
import org.apache.doris.catalog.AggregateType;
|
||||
import org.apache.doris.catalog.Column;
|
||||
import org.apache.doris.common.AnalysisException;
|
||||
import org.apache.doris.common.FeConstants;
|
||||
@ -60,11 +59,7 @@ public class IndexSchemaProcNode implements ProcNodeInterface {
|
||||
// Extra string (aggregation and bloom filter)
|
||||
List<String> extras = Lists.newArrayList();
|
||||
if (column.getAggregationType() != null) {
|
||||
if (column.getAggregationType() == AggregateType.GENERIC_AGGREGATION) {
|
||||
extras.add(column.getGenericAggregationString());
|
||||
} else {
|
||||
extras.add(column.getAggregationType().name());
|
||||
}
|
||||
extras.add(column.getAggregationString());
|
||||
}
|
||||
if (bfColumns != null && bfColumns.contains(column.getName())) {
|
||||
extras.add("BLOOM_FILTER");
|
||||
|
||||
Reference in New Issue
Block a user