diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt
index c19f71e4c6..fd3f02e599 100644
--- a/be/src/vec/CMakeLists.txt
+++ b/be/src/vec/CMakeLists.txt
@@ -20,6 +20,7 @@ set(LIBRARY_OUTPUT_PATH "${BUILD_DIR}/src/vec")
 set(EXECUTABLE_OUTPUT_PATH "${BUILD_DIR}/src/vec")
 
 set(VEC_FILES
+  aggregate_functions/aggregate_function_retention.cpp
   aggregate_functions/aggregate_function_window_funnel.cpp
   aggregate_functions/aggregate_function_avg.cpp
   aggregate_functions/aggregate_function_collect.cpp
diff --git a/be/src/vec/aggregate_functions/aggregate_function_retention.cpp b/be/src/vec/aggregate_functions/aggregate_function_retention.cpp
new file mode 100644
index 0000000000..a84fa3bbe0
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_retention.cpp
@@ -0,0 +1,42 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "vec/aggregate_functions/aggregate_function_retention.h"
+
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/factory_helpers.h"
+#include "vec/aggregate_functions/helpers.h"
+
+namespace doris::vectorized {
+
+// Factory callback for `retention`. Every argument is one event flag, so the argument count
+// must fit into RetentionState::events; for any other count no function is created.
+AggregateFunctionPtr create_aggregate_function_retention(const std::string& name,
+                                                         const DataTypes& argument_types,
+                                                         const Array& parameters,
+                                                         const bool result_is_nullable) {
+    if (argument_types.empty() || argument_types.size() > RetentionState::MAX_EVENTS) {
+        return nullptr;
+    }
+    return std::make_shared<AggregateFunctionRetention>(argument_types);
+}
+
+// Registers the creator above under the function name "retention".
+void register_aggregate_function_retention(AggregateFunctionSimpleFactory& factory) {
+    factory.register_function("retention", create_aggregate_function_retention, false);
+}
+} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/aggregate_functions/aggregate_function_retention.h b/be/src/vec/aggregate_functions/aggregate_function_retention.h
new file mode 100644
index 0000000000..7d667ec671
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_retention.h
@@ -0,0 +1,172 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This file is copied from
+// https://github.com/ClickHouse/ClickHouse/blob/master/AggregateFunctionRetention.h
+// and modified by Doris
+
+#pragma once
+
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/columns/column_array.h"
+#include "vec/columns/columns_number.h"
+#include "vec/data_types/data_type_array.h"
+#include "vec/data_types/data_type_decimal.h"
+#include "vec/io/var_int.h"
+
+namespace doris::vectorized {
+// Aggregation state of `retention`: one flag per event argument, set once any added row
+// had a non-zero value for that argument.
+struct RetentionState {
+    // Capacity of `events`, i.e. the largest number of event arguments one state can track.
+    static constexpr size_t MAX_EVENTS = 32;
+    uint8_t events[MAX_EVENTS] = {0};
+
+    RetentionState() {}
+
+    // Clears every event flag.
+    void reset() {
+        for (int64_t i = 0; i < MAX_EVENTS; i++) {
+            events[i] = 0;
+        }
+    }
+
+    // Marks `event` as seen. An index outside [0, MAX_EVENTS) is ignored rather than written
+    // past the end of `events`.
+    void set(int event) {
+        if (event < 0 || static_cast<size_t>(event) >= MAX_EVENTS) {
+            return;
+        }
+        events[event] = 1;
+    }
+
+    // Folds `other` into this state: a flag is set if either side has it set.
+    void merge(const RetentionState& other) {
+        for (int64_t i = 0; i < MAX_EVENTS; i++) {
+            events[i] |= other.events[i];
+        }
+    }
+
+    // Packs the flags into one integer and writes it as a var-int. events[0] ends up in the
+    // highest used bit; the shift after the last flag leaves bit 0 unused, which read() drops.
+    void write(BufferWritable& out) const {
+        int64_t serialized_events = 0;
+        for (int64_t i = 0; i < MAX_EVENTS; i++) {
+            serialized_events |= events[i];
+            serialized_events <<= 1;
+        }
+        write_var_int(serialized_events, out);
+    }
+
+    // Inverse of write(): discards the unused low bit, then unpacks the flags from
+    // events[MAX_EVENTS - 1] down to events[0].
+    void read(BufferReadable& in) {
+        int64_t serialized_events = 0;
+        uint64_t u_serialized_events = 0;
+        read_var_int(serialized_events, in);
+        u_serialized_events = serialized_events;
+
+        u_serialized_events >>= 1;
+        for (int64_t i = MAX_EVENTS - 1; i >= 0; i--) {
+            events[i] = (uint8)(1 & u_serialized_events);
+            u_serialized_events >>= 1;
+        }
+    }
+
+    // Appends `events_size` result flags to `to`. Slot 0 is events[0]; slot i is 1 only when
+    // both events[0] and events[i] are set. Nothing is appended for events_size == 0, and
+    // slots at or beyond MAX_EVENTS have no backing flag, so they are emitted as 0.
+    void insert_result_into(IColumn& to, size_t events_size, const uint8_t* events) const {
+        if (events_size == 0) {
+            return;
+        }
+        auto& data_to = assert_cast<ColumnUInt8&>(to).get_data();
+
+        ColumnArray::Offset64 current_offset = data_to.size();
+        data_to.resize(current_offset + events_size);
+
+        bool first_flag = events[0];
+        data_to[current_offset] = first_flag;
+        ++current_offset;
+
+        for (size_t i = 1; i < events_size; ++i) {
+            data_to[current_offset] = (first_flag && i < MAX_EVENTS && events[i]);
+            ++current_offset;
+        }
+    }
+};
+
+// Aggregate function `retention(event1, ..., eventN)`. Per group it produces an array of N
+// flags: flag 0 is set when event1 was seen, flag i when both event1 and event(i+1) were seen.
+class AggregateFunctionRetention
+        : public IAggregateFunctionDataHelper<RetentionState, AggregateFunctionRetention> {
+public:
+    AggregateFunctionRetention(const DataTypes& argument_types_)
+            : IAggregateFunctionDataHelper<RetentionState, AggregateFunctionRetention>(
+                      argument_types_, {}) {}
+
+    String get_name() const override { return "retention"; }
+
+    DataTypePtr get_return_type() const override {
+        return std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeUInt8>()));
+    }
+
+    void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); }
+    // Sets the flag of every event argument whose value is non-zero in row `row_num`.
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, const size_t row_num,
+             Arena*) const override {
+        for (int i = 0; i < get_argument_types().size(); i++) {
+            auto event = assert_cast<const ColumnVector<UInt8>*>(columns[i])->get_data()[row_num];
+            if (event) {
+                this->data(place).set(i);
+            }
+        }
+    }
+
+    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+               Arena*) const override {
+        this->data(place).merge(this->data(rhs));
+    }
+
+    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
+        this->data(place).write(buf);
+    }
+
+    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+                     Arena*) const override {
+        this->data(place).read(buf);
+    }
+
+    // Appends one array (one flag per argument) to `to`. When the nested column is nullable,
+    // its null map is extended with 0 for the appended flags.
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
+        auto& to_arr = assert_cast<ColumnArray&>(to);
+        auto& to_nested_col = to_arr.get_data();
+        if (to_nested_col.is_nullable()) {
+            auto col_null = reinterpret_cast<ColumnNullable*>(&to_nested_col);
+            this->data(place).insert_result_into(col_null->get_nested_column(),
+                                                 get_argument_types().size(),
+                                                 this->data(place).events);
+            col_null->get_null_map_data().resize_fill(col_null->get_nested_column().size(), 0);
+        } else {
+            this->data(place).insert_result_into(to_nested_col, get_argument_types().size(),
+                                                 this->data(place).events);
+        }
+        to_arr.get_offsets().push_back(to_nested_col.size());
+    }
+};
+} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp index 79d21985bc..9d4db9182e 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp @@ -48,6 +48,7 @@ void register_aggregate_function_approx_count_distinct(AggregateFunctionSimpleFa void register_aggregate_function_group_concat(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_window_funnel(AggregateFunctionSimpleFactory& factory); +void register_aggregate_function_retention(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_orthogonal_bitmap(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_collect_list(AggregateFunctionSimpleFactory& factory); @@ -74,6 +75,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { register_aggregate_function_percentile(instance); register_aggregate_function_percentile_approx(instance); register_aggregate_function_window_funnel(instance); + register_aggregate_function_retention(instance); register_aggregate_function_orthogonal_bitmap(instance); register_aggregate_function_collect_list(instance); diff --git a/be/test/CMakeLists.txt b/be/test/CMakeLists.txt index e21ff28f6b..651a7678f2 100644 --- a/be/test/CMakeLists.txt +++ b/be/test/CMakeLists.txt @@ -329,6 +329,7 @@ set(VEC_TEST_FILES vec/aggregate_functions/agg_test.cpp vec/aggregate_functions/agg_min_max_test.cpp vec/aggregate_functions/vec_window_funnel_test.cpp + vec/aggregate_functions/vec_retention_test.cpp vec/aggregate_functions/agg_min_max_by_test.cpp vec/core/block_test.cpp vec/core/column_array_test.cpp diff --git a/be/test/vec/aggregate_functions/agg_min_max_by_test.cpp b/be/test/vec/aggregate_functions/agg_min_max_by_test.cpp 
index 9c82214bd2..2008ce2d87 100644 --- a/be/test/vec/aggregate_functions/agg_min_max_by_test.cpp +++ b/be/test/vec/aggregate_functions/agg_min_max_by_test.cpp @@ -86,11 +86,11 @@ TEST_P(AggMinMaxByTest, min_max_by_test) { agg_function->insert_result_into(place, ans); if (i == 0) { // Key type is int32. - ASSERT_EQ(min_max_by_type == "max_by" ? 0 : agg_test_batch_size - 1, + EXPECT_EQ(min_max_by_type == "max_by" ? 0 : agg_test_batch_size - 1, ans.get_element(0)); } else { // Key type is string. - ASSERT_EQ(min_max_by_type == "max_by" ? max_pair.second : min_pair.second, + EXPECT_EQ(min_max_by_type == "max_by" ? max_pair.second : min_pair.second, ans.get_element(0)); } agg_function->destroy(place); diff --git a/be/test/vec/aggregate_functions/vec_retention_test.cpp b/be/test/vec/aggregate_functions/vec_retention_test.cpp new file mode 100644 index 0000000000..223d3192d2 --- /dev/null +++ b/be/test/vec/aggregate_functions/vec_retention_test.cpp @@ -0,0 +1,284 @@ + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include + +#include "common/logging.h" +#include "gtest/gtest.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/aggregate_function_topn.h" +#include "vec/columns/column_array.h" +#include "vec/columns/column_vector.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_array.h" +#include "vec/data_types/data_type_number.h" + +namespace doris::vectorized { + +void register_aggregate_function_retention(AggregateFunctionSimpleFactory& factory); + +class VRetentionTest : public testing::Test { +public: + AggregateFunctionPtr agg_function; + + VRetentionTest() {} + + void SetUp() { + AggregateFunctionSimpleFactory factory = AggregateFunctionSimpleFactory::instance(); + DataTypes data_types = { + std::make_shared(), + std::make_shared(), + std::make_shared(), + }; + Array array; + agg_function = factory.get("retention", data_types, array, false); + EXPECT_NE(agg_function, nullptr); + } + + void TearDown() {} +}; + +TEST_F(VRetentionTest, testEmpty) { + std::unique_ptr memory(new char[agg_function->size_of_data()]); + AggregateDataPtr place = memory.get(); + agg_function->create(place); + + ColumnString buf; + VectorBufferWriter buf_writer(buf); + agg_function->serialize(place, buf_writer); + buf_writer.commit(); + LOG(INFO) << "buf size : " << buf.size(); + VectorBufferReader buf_reader(buf.get_data_at(0)); + agg_function->deserialize(place, buf_reader, nullptr); + + std::unique_ptr memory2(new char[agg_function->size_of_data()]); + AggregateDataPtr place2 = memory2.get(); + agg_function->create(place2); + + agg_function->merge(place, place2, nullptr); + auto column_result = + ColumnArray::create(((DataTypePtr)std::make_shared())->create_column()); + agg_function->insert_result_into(place, *column_result); + auto& result = assert_cast(assert_cast(*column_result).get_data()) + .get_data(); + for (int i = 0; i < result.size(); i++) 
{ + EXPECT_EQ(result[i], 0); + } + + auto column_result2 = + ColumnArray::create(((DataTypePtr)std::make_shared())->create_column()); + agg_function->insert_result_into(place2, *column_result2); + auto& result2 = assert_cast(assert_cast(*column_result2).get_data()) + .get_data(); + for (int i = 0; i < result2.size(); i++) { + EXPECT_EQ(result2[i], 0); + } + + EXPECT_EQ(column_result2->get_offsets()[-1], 0); + EXPECT_EQ(column_result2->get_offsets()[0], 3); + EXPECT_EQ(column_result2->get_offsets().size(), 1); + agg_function->destroy(place); + agg_function->destroy(place2); +} + +TEST_F(VRetentionTest, testSample) { + const int batch_size = 4; + + auto column_event1 = ColumnVector::create(); + column_event1->insert(0); + column_event1->insert(1); + column_event1->insert(0); + column_event1->insert(0); + + auto column_event2 = ColumnVector::create(); + column_event2->insert(0); + column_event2->insert(0); + column_event2->insert(1); + column_event2->insert(0); + + auto column_event3 = ColumnVector::create(); + column_event3->insert(0); + column_event3->insert(0); + column_event3->insert(0); + column_event3->insert(1); + + std::unique_ptr memory(new char[agg_function->size_of_data()]); + AggregateDataPtr place = memory.get(); + agg_function->create(place); + const IColumn* column[3] = {column_event1.get(), column_event2.get(), column_event3.get()}; + for (int i = 0; i < batch_size; i++) { + agg_function->add(place, column, i, nullptr); + } + + std::unique_ptr memory2(new char[agg_function->size_of_data()]); + AggregateDataPtr place2 = memory2.get(); + agg_function->create(place2); + + agg_function->merge(place2, place, nullptr); + + auto column_result2 = + ColumnArray::create(((DataTypePtr)std::make_shared())->create_column()); + agg_function->insert_result_into(place2, *column_result2); + auto& result2 = assert_cast(assert_cast(*column_result2).get_data()) + .get_data(); + for (int i = 0; i < result2.size(); i++) { + EXPECT_EQ(result2[i], 1); + } + + 
EXPECT_EQ(column_result2->get_offsets()[-1], 0); + EXPECT_EQ(column_result2->get_offsets()[0], 3); + EXPECT_EQ(column_result2->get_offsets().size(), 1); + agg_function->destroy(place2); +} + +TEST_F(VRetentionTest, testNoMerge) { + const int batch_size = 4; + + auto column_event1 = ColumnVector::create(); + column_event1->insert(0); + column_event1->insert(1); + column_event1->insert(0); + column_event1->insert(0); + + auto column_event2 = ColumnVector::create(); + column_event2->insert(0); + column_event2->insert(0); + column_event2->insert(1); + column_event2->insert(0); + + auto column_event3 = ColumnVector::create(); + column_event3->insert(0); + column_event3->insert(0); + column_event3->insert(0); + column_event3->insert(1); + + std::unique_ptr memory(new char[agg_function->size_of_data()]); + AggregateDataPtr place = memory.get(); + agg_function->create(place); + const IColumn* column[3] = {column_event1.get(), column_event2.get(), column_event3.get()}; + for (int i = 0; i < batch_size; i++) { + agg_function->add(place, column, i, nullptr); + } + + auto column_result = + ColumnArray::create(((DataTypePtr)std::make_shared())->create_column()); + agg_function->insert_result_into(place, *column_result); + auto& result = assert_cast(assert_cast(*column_result).get_data()) + .get_data(); + for (int i = 0; i < result.size(); i++) { + EXPECT_EQ(result[i], 1); + } + EXPECT_EQ(column_result->get_offsets()[-1], 0); + EXPECT_EQ(column_result->get_offsets()[0], 3); + EXPECT_EQ(column_result->get_offsets().size(), 1); + agg_function->destroy(place); +} + +TEST_F(VRetentionTest, testSerialize) { + const int batch_size = 2; + + auto column_event1 = ColumnVector::create(); + column_event1->insert(0); + column_event1->insert(1); + + auto column_event2 = ColumnVector::create(); + column_event2->insert(0); + column_event2->insert(0); + + auto column_event3 = ColumnVector::create(); + column_event3->insert(0); + column_event3->insert(0); + + std::unique_ptr memory(new 
char[agg_function->size_of_data()]); + AggregateDataPtr place = memory.get(); + agg_function->create(place); + const IColumn* column[3] = {column_event1.get(), column_event2.get(), column_event3.get()}; + for (int i = 0; i < batch_size; i++) { + agg_function->add(place, column, i, nullptr); + } + + ColumnString buf; + VectorBufferWriter buf_writer(buf); + agg_function->serialize(place, buf_writer); + buf_writer.commit(); + agg_function->destroy(place); + + std::unique_ptr memory2(new char[agg_function->size_of_data()]); + AggregateDataPtr place2 = memory2.get(); + agg_function->create(place2); + + VectorBufferReader buf_reader(buf.get_data_at(0)); + agg_function->deserialize(place2, buf_reader, nullptr); + + auto column_result = + ColumnArray::create(((DataTypePtr)std::make_shared())->create_column()); + agg_function->insert_result_into(place2, *column_result); + auto& result = assert_cast(assert_cast(*column_result).get_data()) + .get_data(); + for (int i = 0; i < result.size(); i++) { + if (i == 0) { + EXPECT_EQ(result[i], 1); + } else { + EXPECT_EQ(result[i], 0); + } + } + + auto column_event4 = ColumnVector::create(); + column_event4->insert(0); + column_event4->insert(0); + + auto column_event5 = ColumnVector::create(); + column_event5->insert(0); + column_event5->insert(1); + + auto column_event6 = ColumnVector::create(); + column_event6->insert(0); + column_event6->insert(0); + + std::unique_ptr memory3(new char[agg_function->size_of_data()]); + AggregateDataPtr place3 = memory3.get(); + agg_function->create(place3); + const IColumn* column2[3] = {column_event4.get(), column_event5.get(), column_event6.get()}; + for (int i = 0; i < batch_size; i++) { + agg_function->add(place3, column2, i, nullptr); + } + + agg_function->merge(place2, place3, nullptr); + + auto column_result2 = + ColumnArray::create(((DataTypePtr)std::make_shared())->create_column()); + agg_function->insert_result_into(place2, *column_result2); + auto& result2 = 
assert_cast(assert_cast(*column_result2).get_data()) + .get_data(); + for (int i = 0; i < result2.size(); i++) { + if (i == result2.size() - 1) { + EXPECT_EQ(result2[i], 0); + } else { + EXPECT_EQ(result2[i], 1); + } + } + + EXPECT_EQ(column_result2->get_offsets()[-1], 0); + EXPECT_EQ(column_result2->get_offsets()[0], 3); + EXPECT_EQ(column_result2->get_offsets().size(), 1); + + agg_function->destroy(place2); + agg_function->destroy(place3); +} +} // namespace doris::vectorized diff --git a/docs/en/docs/sql-manual/sql-functions/aggregate-functions/retention.md b/docs/en/docs/sql-manual/sql-functions/aggregate-functions/retention.md new file mode 100644 index 0000000000..821fef13f1 --- /dev/null +++ b/docs/en/docs/sql-manual/sql-functions/aggregate-functions/retention.md @@ -0,0 +1,136 @@ +--- +{ + "title": "RETENTION", + "language": "en" +} +--- + + + +## RETENTION +### Description +#### Syntax + +`retention(event1, event2, ... , eventN);` + +The `retention` function takes as arguments a set of conditions from 1 to 32 arguments of type `UInt8` that indicate whether a certain condition was met for the event. Any condition can be specified as an argument. + +The conditions, except the first, apply in pairs: the result of the second will be true if the first and second are true, of the third if the first and third are true, etc. + +#### Arguments + +`event` — An expression that returns a `UInt8` result (1 or 0). + +##### Returned value + +The array of 1 or 0. + +1 — Condition was met for the event. + +0 — Condition wasn’t met for the event. 
+ +### example + +```sql +DROP TABLE IF EXISTS retention_test; + +CREATE TABLE retention_test( + `uid` int COMMENT 'user id', + `date` datetime COMMENT 'date time' + ) +DUPLICATE KEY(uid) +DISTRIBUTED BY HASH(uid) BUCKETS 3 +PROPERTIES ( + "replication_num" = "1" +); + +INSERT into retention_test (uid, date) values (0, '2022-10-12'), + (0, '2022-10-13'), + (0, '2022-10-14'), + (1, '2022-10-12'), + (1, '2022-10-13'), + (2, '2022-10-12'); + +SELECT * from retention_test; + ++------+---------------------+ +| uid | date | ++------+---------------------+ +| 0 | 2022-10-14 00:00:00 | +| 0 | 2022-10-13 00:00:00 | +| 0 | 2022-10-12 00:00:00 | +| 1 | 2022-10-13 00:00:00 | +| 1 | 2022-10-12 00:00:00 | +| 2 | 2022-10-12 00:00:00 | ++------+---------------------+ + +SELECT + uid, + retention(date = '2022-10-12') + AS r + FROM retention_test + GROUP BY uid + ORDER BY uid ASC; + ++------+------+ +| uid | r | ++------+------+ +| 0 | [1] | +| 1 | [1] | +| 2 | [1] | ++------+------+ + +SELECT + uid, + retention(date = '2022-10-12', date = '2022-10-13') + AS r + FROM retention_test + GROUP BY uid + ORDER BY uid ASC; + ++------+--------+ +| uid | r | ++------+--------+ +| 0 | [1, 1] | +| 1 | [1, 1] | +| 2 | [1, 0] | ++------+--------+ + +SELECT + uid, + retention(date = '2022-10-12', date = '2022-10-13', date = '2022-10-14') + AS r + FROM retention_test + GROUP BY uid + ORDER BY uid ASC; + ++------+-----------+ +| uid | r | ++------+-----------+ +| 0 | [1, 1, 1] | +| 1 | [1, 1, 0] | +| 2 | [1, 0, 0] | ++------+-----------+ + +``` + +### keywords + +RETENTION \ No newline at end of file diff --git a/docs/sidebars.json b/docs/sidebars.json index 3de9e228f1..97f7348ffb 100644 --- a/docs/sidebars.json +++ b/docs/sidebars.json @@ -429,7 +429,8 @@ "sql-manual/sql-functions/aggregate-functions/any_value", "sql-manual/sql-functions/aggregate-functions/var_samp", "sql-manual/sql-functions/aggregate-functions/approx_count_distinct", - "sql-manual/sql-functions/aggregate-functions/variance" + 
"sql-manual/sql-functions/aggregate-functions/variance", + "sql-manual/sql-functions/aggregate-functions/retention" ] }, { diff --git a/docs/zh-CN/docs/sql-manual/sql-functions/aggregate-functions/retention.md b/docs/zh-CN/docs/sql-manual/sql-functions/aggregate-functions/retention.md new file mode 100644 index 0000000000..ab432ff08a --- /dev/null +++ b/docs/zh-CN/docs/sql-manual/sql-functions/aggregate-functions/retention.md @@ -0,0 +1,136 @@ +--- +{ + "title": "RETENTION", + "language": "zh-CN" +} +--- + + + +## RETENTION +### description +#### Syntax + +`retention(event1, event2, ... , eventN);` + +留存函数将一组条件作为参数,类型为1到32个`UInt8`类型的参数,用来表示事件是否满足特定条件。 任何条件都可以指定为参数. + +除了第一个以外,条件成对适用:如果第一个和第二个是真的,第二个结果将是真的,如果第一个和第三个是真的,第三个结果将是真的,等等。 + +#### Arguments + +`event` — 返回`UInt8`结果(1或0)的表达式. + +##### Returned value + +由1和0组成的数组。 + +1 — 条件满足。 + +0 — 条件不满足 + +### example + +```sql +DROP TABLE IF EXISTS retention_test; + +CREATE TABLE retention_test( + `uid` int COMMENT 'user id', + `date` datetime COMMENT 'date time' + ) +DUPLICATE KEY(uid) +DISTRIBUTED BY HASH(uid) BUCKETS 3 +PROPERTIES ( + "replication_num" = "1" +); + +INSERT into retention_test (uid, date) values (0, '2022-10-12'), + (0, '2022-10-13'), + (0, '2022-10-14'), + (1, '2022-10-12'), + (1, '2022-10-13'), + (2, '2022-10-12'); + +SELECT * from retention_test; + ++------+---------------------+ +| uid | date | ++------+---------------------+ +| 0 | 2022-10-14 00:00:00 | +| 0 | 2022-10-13 00:00:00 | +| 0 | 2022-10-12 00:00:00 | +| 1 | 2022-10-13 00:00:00 | +| 1 | 2022-10-12 00:00:00 | +| 2 | 2022-10-12 00:00:00 | ++------+---------------------+ + +SELECT + uid, + retention(date = '2022-10-12') + AS r + FROM retention_test + GROUP BY uid + ORDER BY uid ASC; + ++------+------+ +| uid | r | ++------+------+ +| 0 | [1] | +| 1 | [1] | +| 2 | [1] | ++------+------+ + +SELECT + uid, + retention(date = '2022-10-12', date = '2022-10-13') + AS r + FROM retention_test + GROUP BY uid + ORDER BY uid ASC; + ++------+--------+ +| 
uid | r | ++------+--------+ +| 0 | [1, 1] | +| 1 | [1, 1] | +| 2 | [1, 0] | ++------+--------+ + +SELECT + uid, + retention(date = '2022-10-12', date = '2022-10-13', date = '2022-10-14') + AS r + FROM retention_test + GROUP BY uid + ORDER BY uid ASC; + ++------+-----------+ +| uid | r | ++------+-----------+ +| 0 | [1, 1, 1] | +| 1 | [1, 1, 0] | +| 2 | [1, 0, 0] | ++------+-----------+ + +``` + +### keywords + +RETENTION \ No newline at end of file diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index 0c55205743..bef48ed4a3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -77,7 +77,7 @@ public class FunctionCallExpr extends Expr { private static final ImmutableSet DECIMAL_WIDER_TYPE_SET = new ImmutableSortedSet.Builder(String.CASE_INSENSITIVE_ORDER) .add("sum").add("avg").add("multi_distinct_sum").build(); - private static final ImmutableSet DECIMAL_FUNCTION_SET = + private static final ImmutableSet DECIMAL_FUNCTION_SET = new ImmutableSortedSet.Builder<>(String.CASE_INSENSITIVE_ORDER) .addAll(DECIMAL_SAME_TYPE_SET) .addAll(DECIMAL_WIDER_TYPE_SET) @@ -173,11 +173,11 @@ public class FunctionCallExpr extends Expr { if (!orderByElements.isEmpty()) { if (!VectorizedUtil.isVectorized()) { throw new AnalysisException( - "ORDER BY for arguments only support in vec exec engine"); + "ORDER BY for arguments only support in vec exec engine"); } else if (!AggregateFunction.SUPPORT_ORDER_BY_AGGREGATE_FUNCTION_NAME_SET.contains( fnName.getFunction().toLowerCase())) { throw new AnalysisException( - "ORDER BY not support for the function:" + fnName.getFunction().toLowerCase()); + "ORDER BY not support for the function:" + fnName.getFunction().toLowerCase()); } } setChildren(); @@ -845,6 +845,7 @@ public class FunctionCallExpr extends Expr { /** * This 
analyzeImp used for DefaultValueExprDef * to generate a builtinFunction. + * * @throws AnalysisException */ public void analyzeImplForDefaultValue(Type type) throws AnalysisException { @@ -907,7 +908,7 @@ public class FunctionCallExpr extends Expr { if (!VectorizedUtil.isVectorized()) { type = getChild(0).type.getMaxResolutionType(); } - fn = getBuiltinFunction(fnName.getFunction(), new Type[]{type}, + fn = getBuiltinFunction(fnName.getFunction(), new Type[] {type}, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); } else if (fnName.getFunction().equalsIgnoreCase("count_distinct")) { Type compatibleType = this.children.get(0).getType(); @@ -920,7 +921,7 @@ public class FunctionCallExpr extends Expr { } } - fn = getBuiltinFunction(fnName.getFunction(), new Type[]{compatibleType}, + fn = getBuiltinFunction(fnName.getFunction(), new Type[] {compatibleType}, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); } else if (fnName.getFunction().equalsIgnoreCase(FunctionSet.WINDOW_FUNNEL)) { if (fnParams.exprs() == null || fnParams.exprs().size() < 4) { @@ -962,6 +963,21 @@ public class FunctionCallExpr extends Expr { // cast date to datetime uncheckedCastChild(ScalarType.DATETIMEV2, 2); } + } else if (fnName.getFunction().equalsIgnoreCase(FunctionSet.RETENTION)) { + if (this.children.isEmpty()) { + throw new AnalysisException("The " + fnName + " function must have at least one param"); + } + + Type[] childTypes = new Type[children.size()]; + for (int i = 0; i < children.size(); i++) { + if (children.get(i).type != Type.BOOLEAN) { + throw new AnalysisException("All params of " + + fnName + " function must be boolean"); + } + childTypes[i] = children.get(i).type; + } + fn = getBuiltinFunction(fnName.getFunction(), childTypes, + Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); } else if (fnName.getFunction().equalsIgnoreCase("if")) { Type[] childTypes = collectChildReturnTypes(); Type assignmentCompatibleType = ScalarType.getAssignmentCompatibleType(childTypes[1], childTypes[2], 
true); @@ -980,7 +996,7 @@ public class FunctionCallExpr extends Expr { Type[] newChildTypes = new Type[children.size() - orderByElements.size()]; System.arraycopy(childTypes, 0, newChildTypes, 0, newChildTypes.length); fn = getBuiltinFunction(fnName.getFunction(), newChildTypes, - Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); + Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); } else { // now first find table function in table function sets if (isTableFnCall) { @@ -1457,7 +1473,7 @@ public class FunctionCallExpr extends Expr { || fnName.getFunction().equalsIgnoreCase("avg") || fnName.getFunction().equalsIgnoreCase("weekOfYear")) { Type childType = getChild(0).type; - fn = getBuiltinFunction(fnName.getFunction(), new Type[]{childType}, + fn = getBuiltinFunction(fnName.getFunction(), new Type[] {childType}, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); type = fn.getReturnType(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java index a58097e9eb..ef9dba15cc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java @@ -50,7 +50,8 @@ public class AggregateFunction extends Function { "dense_rank", "multi_distinct_count", "multi_distinct_sum", "hll_union_agg", "hll_union", "bitmap_union", "bitmap_intersect", "orthogonal_bitmap_intersect", "orthogonal_bitmap_intersect_count", "intersect_count", "orthogonal_bitmap_union_count", FunctionSet.COUNT, "approx_count_distinct", "ndv", - FunctionSet.BITMAP_UNION_INT, FunctionSet.BITMAP_UNION_COUNT, "ndv_no_finalize", FunctionSet.WINDOW_FUNNEL); + FunctionSet.BITMAP_UNION_INT, FunctionSet.BITMAP_UNION_COUNT, "ndv_no_finalize", FunctionSet.WINDOW_FUNNEL, + FunctionSet.RETENTION); public static ImmutableSet ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET = ImmutableSet.of("stddev_samp", "variance_samp", "var_samp", 
"percentile_approx"); diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java index 23af2f806f..9be288b4ce 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -1390,6 +1390,8 @@ public class FunctionSet { public static final String COUNT = "count"; public static final String WINDOW_FUNNEL = "window_funnel"; + public static final String RETENTION = "retention"; + // Populate all the aggregate builtins in the catalog. // null symbols indicate the function does not need that step of the evaluation. // An empty symbol indicates a TODO for the BE to implement the function. @@ -1466,6 +1468,21 @@ public class FunctionSet { "", true, false, true, true)); + // retention vectorization + addBuiltin(AggregateFunction.createBuiltin(FunctionSet.RETENTION, + Lists.newArrayList(Type.BOOLEAN), + Type.ARRAY, + Type.VARCHAR, + true, + "", + "", + "", + "", + "", + "", + "", + true, false, true, true)); + for (Type t : Type.getSupportedTypes()) { if (t.isNull()) { continue; // NULL is handled through type promotion. diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/AggregateTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/AggregateTest.java index a5b19b8bfc..40d745ba1c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/analysis/AggregateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/AggregateTest.java @@ -88,6 +88,78 @@ public class AggregateTest extends TestWithFeService { } while (false); } + + @Test + public void testRetentionAnalysisException() throws Exception { + ConnectContext ctx = UtFrameUtils.createDefaultCtx(); + + // Valid calls with boolean arguments: analysis must succeed without throwing. + do { + String query = "select empid, retention(empid = 1, empid = 2) from " + + DB_NAME + "." 
+ TABLE_NAME + " group by empid"; + try { + UtFrameUtils.parseAndAnalyzeStmt(query, ctx); + } catch (Exception e) { + Assert.fail("must not throw for a valid retention call: " + e.getMessage()); + } + } while (false); + + do { + String query = "select empid, retention(empid = 1, empid = 2, empid = 3, empid = 4) from " + + DB_NAME + "." + TABLE_NAME + " group by empid"; + try { + UtFrameUtils.parseAndAnalyzeStmt(query, ctx); + } catch (Exception e) { + Assert.fail("must not throw for a valid retention call: " + e.getMessage()); + } + } while (false); + + // No arguments: analysis must fail with an AnalysisException. + do { + String query = "select empid, retention() from " + + DB_NAME + "." + TABLE_NAME + " group by empid"; + try { + UtFrameUtils.parseAndAnalyzeStmt(query, ctx); + } catch (AnalysisException e) { + Assert.assertTrue(e.getMessage().contains("function must have at least one param")); + break; + } catch (Exception e) { + Assert.fail("must be AnalysisException."); + } + Assert.fail("must be AnalysisException."); + } while (false); + + // Non-boolean arguments: analysis must fail with an AnalysisException. + do { + String query = "select empid, retention('xx', empid = 1) from " + + DB_NAME + "." + TABLE_NAME + " group by empid"; + try { + UtFrameUtils.parseAndAnalyzeStmt(query, ctx); + } catch (AnalysisException e) { + Assert.assertTrue(e.getMessage().contains("All params of retention function must be boolean")); + break; + } catch (Exception e) { + Assert.fail("must be AnalysisException."); + } + Assert.fail("must be AnalysisException."); + } while (false); + + do { + String query = "select empid, retention(1) from " + + DB_NAME + "." 
+ TABLE_NAME + " group by empid"; + try { + UtFrameUtils.parseAndAnalyzeStmt(query, ctx); + } catch (AnalysisException e) { + Assert.assertTrue(e.getMessage().contains("All params of retention function must be boolean")); + break; + } catch (Exception e) { + Assert.fail("must be AnalysisException."); + } + Assert.fail("must be AnalysisException."); + } while (false); + + } + @Test public void testWindowFunnelAnalysisException() throws Exception { ConnectContext ctx = UtFrameUtils.createDefaultCtx(); diff --git a/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_retention.out b/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_retention.out new file mode 100644 index 0000000000..97f1b56680 --- /dev/null +++ b/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_retention.out @@ -0,0 +1,59 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !test_aggregate_retention -- +0 + +-- !test_aggregate_retention_2 -- +0 + +-- !test_aggregate_retention_3 -- +6 + +-- !test_aggregate_retention_4 -- +0 2022-10-14T00:00 +0 2022-10-13T00:00 +0 2022-10-12T00:00 +1 2022-10-13T00:00 +1 2022-10-12T00:00 +2 2022-10-12T00:00 + +-- !test_aggregate_retention_5 -- +0 [1] +1 [1] +2 [1] + +-- !test_aggregate_retention_6 -- +0 [1, 1] +1 [1, 1] +2 [1, 0] + +-- !test_aggregate_retention_7 -- +0 [1, 1, 1] +1 [1, 1, 0] +2 [1, 0, 0] + +-- !test_aggregate_retention_8 -- +0 + +-- !test_aggregate_retention_9 -- +0 2022-10-14T00:00 +0 2022-10-13T00:00 +0 2022-10-12T00:00 +1 2022-10-13T00:00 +1 2022-10-12T00:00 +2 2022-10-12T00:00 + +-- !test_aggregate_retention_10 -- +0 [1] +1 [1] +2 [1] + +-- !test_aggregate_retention_11 -- +0 [1, 1] +1 [1, 1] +2 [1, 0] + +-- !test_aggregate_retention_12 -- +0 [1, 1, 1] +1 [1, 1, 0] +2 [1, 0, 0] + diff --git a/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_retention.sql 
b/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_retention.sql new file mode 100644 index 0000000000..d96c35b6b8 --- /dev/null +++ b/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_retention.sql @@ -0,0 +1,72 @@ +DROP TABLE IF EXISTS retention_test; + +CREATE TABLE retention_test( + `uid` int COMMENT 'user id', + `date` datetime COMMENT 'date time' + ) +DUPLICATE KEY(uid) +DISTRIBUTED BY HASH(uid) BUCKETS 3 +PROPERTIES ( + "replication_num" = "1" +); + +INSERT into retention_test (uid, date) values (0, '2022-10-12'), + (0, '2022-10-13'), + (0, '2022-10-14'), + (1, '2022-10-12'), + (1, '2022-10-13'), + (2, '2022-10-12'); + +SELECT * from retention_test ORDER BY uid ASC, date DESC; + +SELECT + uid, + retention(date = '2022-10-12') + AS r + FROM retention_test + GROUP BY uid + ORDER BY uid ASC; + +SELECT + uid, + retention(date = '2022-10-12', date = '2022-10-13') + AS r + FROM retention_test + GROUP BY uid + ORDER BY uid ASC; + +SELECT + uid, + retention(date = '2022-10-12', date = '2022-10-13', date = '2022-10-14') + AS r + FROM retention_test + GROUP BY uid + ORDER BY uid ASC; + +SET parallel_fragment_exec_instance_num=4; + +SELECT * from retention_test ORDER BY uid ASC, date DESC; + +SELECT + uid, + retention(date = '2022-10-12') + AS r + FROM retention_test + GROUP BY uid + ORDER BY uid ASC; + +SELECT + uid, + retention(date = '2022-10-12', date = '2022-10-13') + AS r + FROM retention_test + GROUP BY uid + ORDER BY uid ASC; + +SELECT + uid, + retention(date = '2022-10-12', date = '2022-10-13', date = '2022-10-14') + AS r + FROM retention_test + GROUP BY uid + ORDER BY uid ASC; \ No newline at end of file