[improvement](function) improve date_trunc function performance when timeunit is const (#25824)
PR #22602 added a check so that only date_trunc(column, const) is supported. The second argument must therefore be a constant literal, so there is no need to check the time unit for every row.
This commit is contained in:
@ -24,6 +24,7 @@
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
@ -378,6 +379,12 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
struct DateTruncState {
|
||||
using Callback_function =
|
||||
std::function<void(const ColumnPtr&, ColumnPtr& res, NullMap& null_map, size_t)>;
|
||||
Callback_function callback_function;
|
||||
};
|
||||
|
||||
template <typename DateType>
|
||||
struct DateTrunc {
|
||||
static constexpr auto name = "date_trunc";
|
||||
@ -396,88 +403,73 @@ struct DateTrunc {
|
||||
return make_nullable(std::make_shared<DateType>());
|
||||
}
|
||||
|
||||
/// Resolves the time-unit argument (argument index 1) once and stores the matching
/// truncation routine in a DateTruncState attached to `context`.
///
/// Only the THREAD_LOCAL scope is handled; any other scope returns OK without
/// touching the context.
///
/// @param context function context providing the constant column and receiving the state
/// @param scope   state scope this call is made for
/// @return InvalidArgument if the time-unit argument is not constant, RuntimeError if
///         it is not one of the supported units, OK otherwise.
static Status open(FunctionContext* context, FunctionContext::FunctionStateScope scope) {
    if (scope != FunctionContext::THREAD_LOCAL) {
        return Status::OK();
    }
    if (!context->is_col_constant(1)) {
        return Status::InvalidArgument(
                "date_trunc function of time unit argument must be constant.");
    }

    // Unit names are matched case-insensitively: lower-case a private copy first.
    const auto& data_str = context->get_constant_col(1)->column_ptr->get_data_at(0);
    std::string lower_str(data_str.data, data_str.size);
    std::transform(lower_str.begin(), lower_str.end(), lower_str.begin(),
                   [](unsigned char c) { return std::tolower(c); });

    // Compare the whole string rather than a prefix, so that inputs such as
    // "yearly" or "daytime" are rejected instead of silently mapping to a unit.
    std::shared_ptr<DateTruncState> state = std::make_shared<DateTruncState>();
    if (lower_str == "year") {
        state->callback_function = &execute_impl_right_const<TimeUnit::YEAR>;
    } else if (lower_str == "quarter") {
        state->callback_function = &execute_impl_right_const<TimeUnit::QUARTER>;
    } else if (lower_str == "month") {
        state->callback_function = &execute_impl_right_const<TimeUnit::MONTH>;
    } else if (lower_str == "week") {
        state->callback_function = &execute_impl_right_const<TimeUnit::WEEK>;
    } else if (lower_str == "day") {
        state->callback_function = &execute_impl_right_const<TimeUnit::DAY>;
    } else if (lower_str == "hour") {
        state->callback_function = &execute_impl_right_const<TimeUnit::HOUR>;
    } else if (lower_str == "minute") {
        state->callback_function = &execute_impl_right_const<TimeUnit::MINUTE>;
    } else if (lower_str == "second") {
        state->callback_function = &execute_impl_right_const<TimeUnit::SECOND>;
    } else {
        return Status::RuntimeError(
                "Illegal second argument column of function date_trunc. now only support "
                "[second,minute,hour,day,week,month,quarter,year]");
    }
    context->set_function_state(scope, state);
    return Status::OK();
}
|
||||
|
||||
/// Truncates every value of the first argument column to the time unit that was
/// resolved when the function state was created, and stores a nullable result
/// column at position `result` of `block`.
///
/// The previous body mixed a per-row string-dispatch path with the state-based
/// path, declaring `datetime_column` and `res` twice; only the state-based path
/// is kept.
///
/// @param context          function context holding the THREAD_LOCAL DateTruncState
/// @param block            block containing the argument columns and the result slot
/// @param arguments        positions of the two arguments inside `block`
/// @param result           position of the result column inside `block`
/// @param input_rows_count number of rows to process
/// @return always Status::OK()
static Status execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
                      size_t result, size_t input_rows_count) {
    DCHECK_EQ(arguments.size(), 2);

    // One null flag per row, initialised to 0; the callback overwrites each entry.
    auto null_map = ColumnUInt8::create(input_rows_count, 0);

    // The callback indexes the input row by row, so expand a constant column first.
    const auto& datetime_column =
            block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();

    // Pre-sized result: the callback writes res[i] directly without resizing.
    ColumnPtr res = ColumnType::create(input_rows_count);

    auto* state = reinterpret_cast<DateTruncState*>(
            context->get_function_state(FunctionContext::THREAD_LOCAL));
    DCHECK(state != nullptr);
    state->callback_function(datetime_column, res, null_map->get_data(), input_rows_count);

    block.get_by_position(result).column = ColumnNullable::create(res, std::move(null_map));
    return Status::OK();
}
|
||||
|
||||
private:
|
||||
static void execute_impl(const PaddedPODArray<ArgType>& ldata, const ColumnString::Chars& rdata,
|
||||
const ColumnString::Offsets& roffsets, PaddedPODArray<ArgType>& res,
|
||||
NullMap& null_map, size_t input_rows_count) {
|
||||
res.resize(input_rows_count);
|
||||
for (size_t i = 0; i < input_rows_count; ++i) {
|
||||
auto dt = binary_cast<ArgType, DateValueType>(ldata[i]);
|
||||
const char* str_data = reinterpret_cast<const char*>(&rdata[roffsets[i - 1]]);
|
||||
_execute_inner_loop(dt, str_data, res, null_map, i);
|
||||
}
|
||||
}
|
||||
static void execute_impl_right_const(const PaddedPODArray<ArgType>& ldata,
|
||||
const StringRef& rdata, PaddedPODArray<ArgType>& res,
|
||||
template <TimeUnit Unit>
|
||||
static void execute_impl_right_const(const ColumnPtr& datetime_column, ColumnPtr& result_column,
|
||||
NullMap& null_map, size_t input_rows_count) {
|
||||
res.resize(input_rows_count);
|
||||
std::string lower_str(rdata.data, rdata.size);
|
||||
std::transform(lower_str.begin(), lower_str.end(), lower_str.begin(),
|
||||
[](unsigned char c) { return std::tolower(c); });
|
||||
auto& data = static_cast<const ColumnType*>(datetime_column.get())->get_data();
|
||||
auto& res = static_cast<ColumnType*>(result_column->assume_mutable().get())->get_data();
|
||||
for (size_t i = 0; i < input_rows_count; ++i) {
|
||||
auto dt = binary_cast<ArgType, DateValueType>(ldata[i]);
|
||||
_execute_inner_loop(dt, lower_str.data(), res, null_map, i);
|
||||
auto dt = binary_cast<ArgType, DateValueType>(data[i]);
|
||||
null_map[i] = !dt.template datetime_trunc<Unit>();
|
||||
res[i] = binary_cast<DateValueType, ArgType>(dt);
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
static void _execute_inner_loop(T& dt, const char* str_data, PaddedPODArray<ArgType>& res,
|
||||
NullMap& null_map, size_t index) {
|
||||
if (std::strncmp("year", str_data, 4) == 0) {
|
||||
null_map[index] = !dt.template datetime_trunc<YEAR>();
|
||||
} else if (std::strncmp("quarter", str_data, 7) == 0) {
|
||||
null_map[index] = !dt.template datetime_trunc<QUARTER>();
|
||||
} else if (std::strncmp("month", str_data, 5) == 0) {
|
||||
null_map[index] = !dt.template datetime_trunc<MONTH>();
|
||||
} else if (std::strncmp("week", str_data, 4) == 0) {
|
||||
null_map[index] = !dt.template datetime_trunc<WEEK>();
|
||||
} else if (std::strncmp("day", str_data, 3) == 0) {
|
||||
null_map[index] = !dt.template datetime_trunc<DAY>();
|
||||
} else if (std::strncmp("hour", str_data, 4) == 0) {
|
||||
null_map[index] = !dt.template datetime_trunc<HOUR>();
|
||||
} else if (std::strncmp("minute", str_data, 6) == 0) {
|
||||
null_map[index] = !dt.template datetime_trunc<MINUTE>();
|
||||
} else if (std::strncmp("second", str_data, 6) == 0) {
|
||||
null_map[index] = !dt.template datetime_trunc<SECOND>();
|
||||
} else {
|
||||
null_map[index] = 1;
|
||||
}
|
||||
res[index] = binary_cast<DateValueType, ArgType>(dt);
|
||||
}
|
||||
};
|
||||
|
||||
class FromDays : public IFunction {
|
||||
@ -1263,6 +1255,17 @@ public:
|
||||
return Impl::get_return_type_impl(arguments);
|
||||
}
|
||||
|
||||
/// Forwards to Impl::open when Impl is one of the DateTrunc instantiations;
/// for every other Impl this is a no-op that returns OK.
Status open(FunctionContext* context, FunctionContext::FunctionStateScope scope) override {
    constexpr bool is_date_trunc = std::is_same_v<Impl, DateTrunc<DataTypeDate>> ||
                                   std::is_same_v<Impl, DateTrunc<DataTypeDateV2>> ||
                                   std::is_same_v<Impl, DateTrunc<DataTypeDateTime>> ||
                                   std::is_same_v<Impl, DateTrunc<DataTypeDateTimeV2>>;
    if constexpr (!is_date_trunc) {
        return Status::OK();
    } else {
        return Impl::open(context, scope);
    }
}
|
||||
|
||||
// TODO: enable the function below once be-ut has been fixed.
|
||||
//ColumnNumbers get_arguments_that_are_always_constant() const override { return {1}; }
|
||||
|
||||
|
||||
Reference in New Issue
Block a user