[improvement](function) improve date_trunc function performance when timeunit is const (#25824)

PR #22602 already added a check for this function:
only date_trunc(column, const) is supported, so the second argument must be a constant literal
and there is no need to check the time unit for every row.
This commit is contained in:
zhangstar333
2023-10-26 09:51:21 +08:00
committed by GitHub
parent 77f727e0a1
commit da4de17d5c
3 changed files with 142 additions and 96 deletions

View File

@ -24,6 +24,7 @@
#include <cstring>
#include <memory>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
@ -378,6 +379,12 @@ private:
}
};
// Per-function state for date_trunc: holds the truncation routine that was
// selected for the constant time-unit argument.
struct DateTruncState {
    // Signature of a truncation routine:
    //   (input column, result column, null map, number of rows).
    using Callback_function =
            std::function<void(const ColumnPtr&, ColumnPtr&, NullMap&, size_t)>;

    // Routine invoked by execute(); empty until one is assigned.
    Callback_function callback_function;
};
template <typename DateType>
struct DateTrunc {
static constexpr auto name = "date_trunc";
@ -396,88 +403,73 @@ struct DateTrunc {
return make_nullable(std::make_shared<DateType>());
}
// Resolves the time-unit argument (argument index 1) once and stores the
// matching truncation routine in a DateTruncState registered on `context`.
//
// Returns:
//   - OK without doing anything when `scope` is not THREAD_LOCAL;
//   - InvalidArgument when argument 1 is not a constant column;
//   - RuntimeError when the unit is not one of the supported keywords;
//   - OK after the state has been registered otherwise.
static Status open(FunctionContext* context, FunctionContext::FunctionStateScope scope) {
    if (scope != FunctionContext::THREAD_LOCAL) {
        return Status::OK();
    }
    if (!context->is_col_constant(1)) {
        return Status::InvalidArgument(
                "date_trunc function of time unit argument must be constant.");
    }

    // Lower-case copy of the constant so that matching is case-insensitive.
    const auto& unit_ref = context->get_constant_col(1)->column_ptr->get_data_at(0);
    std::string unit(unit_ref.data, unit_ref.size);
    std::transform(unit.begin(), unit.end(), unit.begin(),
                   [](unsigned char ch) { return std::tolower(ch); });

    using Kernel = void (*)(const ColumnPtr&, ColumnPtr&, NullMap&, size_t);
    struct UnitEntry {
        const char* keyword;
        size_t length;
        Kernel kernel;
    };
    // Checked in this order. Only the first `length` characters are compared,
    // so the match is a prefix match on the lower-cased unit.
    const UnitEntry known_units[] = {
            {"year", 4, &execute_impl_right_const<TimeUnit::YEAR>},
            {"quarter", 7, &execute_impl_right_const<TimeUnit::QUARTER>},
            {"month", 5, &execute_impl_right_const<TimeUnit::MONTH>},
            {"week", 4, &execute_impl_right_const<TimeUnit::WEEK>},
            {"day", 3, &execute_impl_right_const<TimeUnit::DAY>},
            {"hour", 4, &execute_impl_right_const<TimeUnit::HOUR>},
            {"minute", 6, &execute_impl_right_const<TimeUnit::MINUTE>},
            {"second", 6, &execute_impl_right_const<TimeUnit::SECOND>},
    };

    auto state = std::make_shared<DateTruncState>();
    for (const auto& entry : known_units) {
        if (std::strncmp(entry.keyword, unit.data(), entry.length) == 0) {
            state->callback_function = entry.kernel;
            context->set_function_state(scope, state);
            return Status::OK();
        }
    }
    return Status::RuntimeError(
            "Illegal second argument column of function date_trunc. now only support "
            "[second,minute,hour,day,week,month,quarter,year]");
}
// Truncates every value of the first argument column using the routine that
// open() stored in the THREAD_LOCAL DateTruncState, and writes a nullable
// result column to `block` at position `result`.
//
// The time unit is not inspected here: it was resolved once in open(), so the
// second argument column is never read per row.
//
// Returns RuntimeError when no state/routine is available, OK otherwise.
static Status execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
                      size_t result, size_t input_rows_count) {
    DCHECK_EQ(arguments.size(), 2);

    // Every row starts as non-null (0); the truncation routine overwrites the
    // flag for each row it processes.
    auto null_map = ColumnUInt8::create(input_rows_count, 0);
    const auto& datetime_column =
            block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
    // Pre-sized: the truncation routine writes by index and does not resize.
    ColumnPtr res = ColumnType::create(input_rows_count);

    auto* state = reinterpret_cast<DateTruncState*>(
            context->get_function_state(FunctionContext::THREAD_LOCAL));
    DCHECK(state != nullptr);
    // DCHECK may compile out; fail with a Status instead of dereferencing a
    // null state or calling an empty std::function.
    if (state == nullptr || !state->callback_function) {
        return Status::RuntimeError(
                "date_trunc function state is not initialized, open() must succeed before "
                "execute().");
    }

    state->callback_function(datetime_column, res, null_map->get_data(), input_rows_count);
    block.get_by_position(result).column = ColumnNullable::create(res, std::move(null_map));
    return Status::OK();
}
private:
static void execute_impl(const PaddedPODArray<ArgType>& ldata, const ColumnString::Chars& rdata,
const ColumnString::Offsets& roffsets, PaddedPODArray<ArgType>& res,
NullMap& null_map, size_t input_rows_count) {
res.resize(input_rows_count);
for (size_t i = 0; i < input_rows_count; ++i) {
auto dt = binary_cast<ArgType, DateValueType>(ldata[i]);
const char* str_data = reinterpret_cast<const char*>(&rdata[roffsets[i - 1]]);
_execute_inner_loop(dt, str_data, res, null_map, i);
}
}
static void execute_impl_right_const(const PaddedPODArray<ArgType>& ldata,
const StringRef& rdata, PaddedPODArray<ArgType>& res,
template <TimeUnit Unit>
static void execute_impl_right_const(const ColumnPtr& datetime_column, ColumnPtr& result_column,
NullMap& null_map, size_t input_rows_count) {
res.resize(input_rows_count);
std::string lower_str(rdata.data, rdata.size);
std::transform(lower_str.begin(), lower_str.end(), lower_str.begin(),
[](unsigned char c) { return std::tolower(c); });
auto& data = static_cast<const ColumnType*>(datetime_column.get())->get_data();
auto& res = static_cast<ColumnType*>(result_column->assume_mutable().get())->get_data();
for (size_t i = 0; i < input_rows_count; ++i) {
auto dt = binary_cast<ArgType, DateValueType>(ldata[i]);
_execute_inner_loop(dt, lower_str.data(), res, null_map, i);
auto dt = binary_cast<ArgType, DateValueType>(data[i]);
null_map[i] = !dt.template datetime_trunc<Unit>();
res[i] = binary_cast<DateValueType, ArgType>(dt);
}
}
// Truncates `dt` in place to the unit named by `str_data` and stores the
// outcome for row `index`.
//
// Only the first N characters of `str_data` are compared against each
// keyword (N = keyword length). When no keyword matches, or when the
// truncation reports failure, null_map[index] is set to 1. res[index]
// always receives `dt` converted back to ArgType.
template <typename T>
static void _execute_inner_loop(T& dt, const char* str_data, PaddedPODArray<ArgType>& res,
                                NullMap& null_map, size_t index) {
    const auto unit_is = [str_data](const char* keyword, size_t length) {
        return std::strncmp(keyword, str_data, length) == 0;
    };

    // Stays false when the unit is not recognised, which marks the row null.
    bool truncated = false;
    if (unit_is("year", 4)) {
        truncated = dt.template datetime_trunc<YEAR>();
    } else if (unit_is("quarter", 7)) {
        truncated = dt.template datetime_trunc<QUARTER>();
    } else if (unit_is("month", 5)) {
        truncated = dt.template datetime_trunc<MONTH>();
    } else if (unit_is("week", 4)) {
        truncated = dt.template datetime_trunc<WEEK>();
    } else if (unit_is("day", 3)) {
        truncated = dt.template datetime_trunc<DAY>();
    } else if (unit_is("hour", 4)) {
        truncated = dt.template datetime_trunc<HOUR>();
    } else if (unit_is("minute", 6)) {
        truncated = dt.template datetime_trunc<MINUTE>();
    } else if (unit_is("second", 6)) {
        truncated = dt.template datetime_trunc<SECOND>();
    }

    null_map[index] = !truncated;
    res[index] = binary_cast<DateValueType, ArgType>(dt);
}
};
class FromDays : public IFunction {
@ -1263,6 +1255,17 @@ public:
return Impl::get_return_type_impl(arguments);
}
// Forwards to Impl::open() when Impl is one of the DateTrunc instantiations;
// every other Impl has nothing to prepare and reports OK.
Status open(FunctionContext* context, FunctionContext::FunctionStateScope scope) override {
    constexpr bool impl_is_date_trunc = std::is_same_v<Impl, DateTrunc<DataTypeDate>> ||
                                        std::is_same_v<Impl, DateTrunc<DataTypeDateV2>> ||
                                        std::is_same_v<Impl, DateTrunc<DataTypeDateTime>> ||
                                        std::is_same_v<Impl, DateTrunc<DataTypeDateTimeV2>>;
    // `if constexpr` keeps Impl::open from being instantiated for other Impl types.
    if constexpr (impl_is_date_trunc) {
        return Impl::open(context, scope);
    } else {
        return Status::OK();
    }
}
// TODO: enable the function below once the BE unit tests (be-ut) are fixed.
//ColumnNumbers get_arguments_that_are_always_constant() const override { return {1}; }