[Feature](Retention) support retention function (#13056)

This commit is contained in:
abmdocrt
2022-10-17 11:00:47 +08:00
committed by GitHub
parent 6ea9a65bb6
commit 045bccdbea
16 changed files with 990 additions and 11 deletions

View File

@ -20,6 +20,7 @@ set(LIBRARY_OUTPUT_PATH "${BUILD_DIR}/src/vec")
set(EXECUTABLE_OUTPUT_PATH "${BUILD_DIR}/src/vec")
set(VEC_FILES
aggregate_functions/aggregate_function_retention.cpp
aggregate_functions/aggregate_function_window_funnel.cpp
aggregate_functions/aggregate_function_avg.cpp
aggregate_functions/aggregate_function_collect.cpp

View File

@ -0,0 +1,36 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "vec/aggregate_functions/aggregate_function_retention.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/factory_helpers.h"
#include "vec/aggregate_functions/helpers.h"
namespace doris::vectorized {
AggregateFunctionPtr create_aggregate_function_retention(const std::string& name,
const DataTypes& argument_types,
const Array& parameters,
const bool result_is_nullable) {
return std::make_shared<AggregateFunctionRetention>(argument_types);
}
void register_aggregate_function_retention(AggregateFunctionSimpleFactory& factory) {
factory.register_function("retention", create_aggregate_function_retention, false);
}
} // namespace doris::vectorized

View File

@ -0,0 +1,145 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/ClickHouse/ClickHouse/blob/master/AggregateFunctionRetention.h
// and modified by Doris
#pragma once
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column_array.h"
#include "vec/columns/columns_number.h"
#include "vec/data_types/data_type_array.h"
#include "vec/data_types/data_type_decimal.h"
#include "vec/io/var_int.h"
namespace doris::vectorized {
struct RetentionState {
static constexpr size_t MAX_EVENTS = 32;
uint8_t events[MAX_EVENTS] = {0};
RetentionState() {}
void reset() {
for (int64_t i = 0; i < MAX_EVENTS; i++) {
events[i] = 0;
}
}
void set(int event) { events[event] = 1; }
void merge(const RetentionState& other) {
for (int64_t i = 0; i < MAX_EVENTS; i++) {
events[i] |= other.events[i];
}
}
void write(BufferWritable& out) const {
int64_t serialized_events = 0;
for (int64_t i = 0; i < MAX_EVENTS; i++) {
serialized_events |= events[i];
serialized_events <<= 1;
}
write_var_int(serialized_events, out);
}
void read(BufferReadable& in) {
int64_t serialized_events = 0;
uint64_t u_serialized_events = 0;
read_var_int(serialized_events, in);
u_serialized_events = serialized_events;
u_serialized_events >>= 1;
for (int64_t i = MAX_EVENTS - 1; i >= 0; i--) {
events[i] = (uint8)(1 & u_serialized_events);
u_serialized_events >>= 1;
}
}
void insert_result_into(IColumn& to, size_t events_size, const uint8_t* events) const {
auto& data_to = assert_cast<ColumnUInt8&>(to).get_data();
ColumnArray::Offset64 current_offset = data_to.size();
data_to.resize(current_offset + events_size);
bool first_flag = events[0];
data_to[current_offset] = first_flag;
++current_offset;
for (size_t i = 1; i < events_size; ++i) {
data_to[current_offset] = (first_flag && events[i]);
++current_offset;
}
}
};
class AggregateFunctionRetention
: public IAggregateFunctionDataHelper<RetentionState, AggregateFunctionRetention> {
public:
AggregateFunctionRetention(const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<RetentionState, AggregateFunctionRetention>(
argument_types_, {}) {}
String get_name() const override { return "retention"; }
DataTypePtr get_return_type() const override {
return std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeUInt8>()));
}
void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); }
void add(AggregateDataPtr __restrict place, const IColumn** columns, const size_t row_num,
Arena*) const override {
for (int i = 0; i < get_argument_types().size(); i++) {
auto event = assert_cast<const ColumnVector<UInt8>*>(columns[i])->get_data()[row_num];
if (event) {
this->data(place).set(i);
}
}
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena*) const override {
this->data(place).merge(this->data(rhs));
}
void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
this->data(place).write(buf);
}
void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena*) const override {
this->data(place).read(buf);
}
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
auto& to_arr = assert_cast<ColumnArray&>(to);
auto& to_nested_col = to_arr.get_data();
if (to_nested_col.is_nullable()) {
auto col_null = reinterpret_cast<ColumnNullable*>(&to_nested_col);
this->data(place).insert_result_into(col_null->get_nested_column(),
get_argument_types().size(),
this->data(place).events);
col_null->get_null_map_data().resize_fill(col_null->get_nested_column().size(), 0);
} else {
this->data(place).insert_result_into(to_nested_col, get_argument_types().size(),
this->data(place).events);
}
to_arr.get_offsets().push_back(to_nested_col.size());
}
};
} // namespace doris::vectorized

View File

@ -48,6 +48,7 @@ void register_aggregate_function_approx_count_distinct(AggregateFunctionSimpleFa
void register_aggregate_function_group_concat(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_window_funnel(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_retention(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_orthogonal_bitmap(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_collect_list(AggregateFunctionSimpleFactory& factory);
@ -74,6 +75,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
register_aggregate_function_percentile(instance);
register_aggregate_function_percentile_approx(instance);
register_aggregate_function_window_funnel(instance);
register_aggregate_function_retention(instance);
register_aggregate_function_orthogonal_bitmap(instance);
register_aggregate_function_collect_list(instance);

View File

@ -329,6 +329,7 @@ set(VEC_TEST_FILES
vec/aggregate_functions/agg_test.cpp
vec/aggregate_functions/agg_min_max_test.cpp
vec/aggregate_functions/vec_window_funnel_test.cpp
vec/aggregate_functions/vec_retention_test.cpp
vec/aggregate_functions/agg_min_max_by_test.cpp
vec/core/block_test.cpp
vec/core/column_array_test.cpp

View File

@ -86,11 +86,11 @@ TEST_P(AggMinMaxByTest, min_max_by_test) {
agg_function->insert_result_into(place, ans);
if (i == 0) {
// Key type is int32.
ASSERT_EQ(min_max_by_type == "max_by" ? 0 : agg_test_batch_size - 1,
EXPECT_EQ(min_max_by_type == "max_by" ? 0 : agg_test_batch_size - 1,
ans.get_element(0));
} else {
// Key type is string.
ASSERT_EQ(min_max_by_type == "max_by" ? max_pair.second : min_pair.second,
EXPECT_EQ(min_max_by_type == "max_by" ? max_pair.second : min_pair.second,
ans.get_element(0));
}
agg_function->destroy(place);

View File

@ -0,0 +1,284 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <gtest/gtest.h>
#include "common/logging.h"
#include "gtest/gtest.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/aggregate_function_topn.h"
#include "vec/columns/column_array.h"
#include "vec/columns/column_vector.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_array.h"
#include "vec/data_types/data_type_number.h"
namespace doris::vectorized {
void register_aggregate_function_retention(AggregateFunctionSimpleFactory& factory);
class VRetentionTest : public testing::Test {
public:
AggregateFunctionPtr agg_function;
VRetentionTest() {}
void SetUp() {
AggregateFunctionSimpleFactory factory = AggregateFunctionSimpleFactory::instance();
DataTypes data_types = {
std::make_shared<DataTypeUInt8>(),
std::make_shared<DataTypeUInt8>(),
std::make_shared<DataTypeUInt8>(),
};
Array array;
agg_function = factory.get("retention", data_types, array, false);
EXPECT_NE(agg_function, nullptr);
}
void TearDown() {}
};
TEST_F(VRetentionTest, testEmpty) {
std::unique_ptr<char[]> memory(new char[agg_function->size_of_data()]);
AggregateDataPtr place = memory.get();
agg_function->create(place);
ColumnString buf;
VectorBufferWriter buf_writer(buf);
agg_function->serialize(place, buf_writer);
buf_writer.commit();
LOG(INFO) << "buf size : " << buf.size();
VectorBufferReader buf_reader(buf.get_data_at(0));
agg_function->deserialize(place, buf_reader, nullptr);
std::unique_ptr<char[]> memory2(new char[agg_function->size_of_data()]);
AggregateDataPtr place2 = memory2.get();
agg_function->create(place2);
agg_function->merge(place, place2, nullptr);
auto column_result =
ColumnArray::create(((DataTypePtr)std::make_shared<DataTypeUInt8>())->create_column());
agg_function->insert_result_into(place, *column_result);
auto& result = assert_cast<ColumnUInt8&>(assert_cast<ColumnArray&>(*column_result).get_data())
.get_data();
for (int i = 0; i < result.size(); i++) {
EXPECT_EQ(result[i], 0);
}
auto column_result2 =
ColumnArray::create(((DataTypePtr)std::make_shared<DataTypeUInt8>())->create_column());
agg_function->insert_result_into(place2, *column_result2);
auto& result2 = assert_cast<ColumnUInt8&>(assert_cast<ColumnArray&>(*column_result2).get_data())
.get_data();
for (int i = 0; i < result2.size(); i++) {
EXPECT_EQ(result2[i], 0);
}
EXPECT_EQ(column_result2->get_offsets()[-1], 0);
EXPECT_EQ(column_result2->get_offsets()[0], 3);
EXPECT_EQ(column_result2->get_offsets().size(), 1);
agg_function->destroy(place);
agg_function->destroy(place2);
}
TEST_F(VRetentionTest, testSample) {
const int batch_size = 4;
auto column_event1 = ColumnVector<UInt8>::create();
column_event1->insert(0);
column_event1->insert(1);
column_event1->insert(0);
column_event1->insert(0);
auto column_event2 = ColumnVector<UInt8>::create();
column_event2->insert(0);
column_event2->insert(0);
column_event2->insert(1);
column_event2->insert(0);
auto column_event3 = ColumnVector<UInt8>::create();
column_event3->insert(0);
column_event3->insert(0);
column_event3->insert(0);
column_event3->insert(1);
std::unique_ptr<char[]> memory(new char[agg_function->size_of_data()]);
AggregateDataPtr place = memory.get();
agg_function->create(place);
const IColumn* column[3] = {column_event1.get(), column_event2.get(), column_event3.get()};
for (int i = 0; i < batch_size; i++) {
agg_function->add(place, column, i, nullptr);
}
std::unique_ptr<char[]> memory2(new char[agg_function->size_of_data()]);
AggregateDataPtr place2 = memory2.get();
agg_function->create(place2);
agg_function->merge(place2, place, nullptr);
auto column_result2 =
ColumnArray::create(((DataTypePtr)std::make_shared<DataTypeUInt8>())->create_column());
agg_function->insert_result_into(place2, *column_result2);
auto& result2 = assert_cast<ColumnUInt8&>(assert_cast<ColumnArray&>(*column_result2).get_data())
.get_data();
for (int i = 0; i < result2.size(); i++) {
EXPECT_EQ(result2[i], 1);
}
EXPECT_EQ(column_result2->get_offsets()[-1], 0);
EXPECT_EQ(column_result2->get_offsets()[0], 3);
EXPECT_EQ(column_result2->get_offsets().size(), 1);
agg_function->destroy(place2);
}
TEST_F(VRetentionTest, testNoMerge) {
const int batch_size = 4;
auto column_event1 = ColumnVector<UInt8>::create();
column_event1->insert(0);
column_event1->insert(1);
column_event1->insert(0);
column_event1->insert(0);
auto column_event2 = ColumnVector<UInt8>::create();
column_event2->insert(0);
column_event2->insert(0);
column_event2->insert(1);
column_event2->insert(0);
auto column_event3 = ColumnVector<UInt8>::create();
column_event3->insert(0);
column_event3->insert(0);
column_event3->insert(0);
column_event3->insert(1);
std::unique_ptr<char[]> memory(new char[agg_function->size_of_data()]);
AggregateDataPtr place = memory.get();
agg_function->create(place);
const IColumn* column[3] = {column_event1.get(), column_event2.get(), column_event3.get()};
for (int i = 0; i < batch_size; i++) {
agg_function->add(place, column, i, nullptr);
}
auto column_result =
ColumnArray::create(((DataTypePtr)std::make_shared<DataTypeUInt8>())->create_column());
agg_function->insert_result_into(place, *column_result);
auto& result = assert_cast<ColumnUInt8&>(assert_cast<ColumnArray&>(*column_result).get_data())
.get_data();
for (int i = 0; i < result.size(); i++) {
EXPECT_EQ(result[i], 1);
}
EXPECT_EQ(column_result->get_offsets()[-1], 0);
EXPECT_EQ(column_result->get_offsets()[0], 3);
EXPECT_EQ(column_result->get_offsets().size(), 1);
agg_function->destroy(place);
}
TEST_F(VRetentionTest, testSerialize) {
const int batch_size = 2;
auto column_event1 = ColumnVector<UInt8>::create();
column_event1->insert(0);
column_event1->insert(1);
auto column_event2 = ColumnVector<UInt8>::create();
column_event2->insert(0);
column_event2->insert(0);
auto column_event3 = ColumnVector<UInt8>::create();
column_event3->insert(0);
column_event3->insert(0);
std::unique_ptr<char[]> memory(new char[agg_function->size_of_data()]);
AggregateDataPtr place = memory.get();
agg_function->create(place);
const IColumn* column[3] = {column_event1.get(), column_event2.get(), column_event3.get()};
for (int i = 0; i < batch_size; i++) {
agg_function->add(place, column, i, nullptr);
}
ColumnString buf;
VectorBufferWriter buf_writer(buf);
agg_function->serialize(place, buf_writer);
buf_writer.commit();
agg_function->destroy(place);
std::unique_ptr<char[]> memory2(new char[agg_function->size_of_data()]);
AggregateDataPtr place2 = memory2.get();
agg_function->create(place2);
VectorBufferReader buf_reader(buf.get_data_at(0));
agg_function->deserialize(place2, buf_reader, nullptr);
auto column_result =
ColumnArray::create(((DataTypePtr)std::make_shared<DataTypeUInt8>())->create_column());
agg_function->insert_result_into(place2, *column_result);
auto& result = assert_cast<ColumnUInt8&>(assert_cast<ColumnArray&>(*column_result).get_data())
.get_data();
for (int i = 0; i < result.size(); i++) {
if (i == 0) {
EXPECT_EQ(result[i], 1);
} else {
EXPECT_EQ(result[i], 0);
}
}
auto column_event4 = ColumnVector<UInt8>::create();
column_event4->insert(0);
column_event4->insert(0);
auto column_event5 = ColumnVector<UInt8>::create();
column_event5->insert(0);
column_event5->insert(1);
auto column_event6 = ColumnVector<UInt8>::create();
column_event6->insert(0);
column_event6->insert(0);
std::unique_ptr<char[]> memory3(new char[agg_function->size_of_data()]);
AggregateDataPtr place3 = memory3.get();
agg_function->create(place3);
const IColumn* column2[3] = {column_event4.get(), column_event5.get(), column_event6.get()};
for (int i = 0; i < batch_size; i++) {
agg_function->add(place3, column2, i, nullptr);
}
agg_function->merge(place2, place3, nullptr);
auto column_result2 =
ColumnArray::create(((DataTypePtr)std::make_shared<DataTypeUInt8>())->create_column());
agg_function->insert_result_into(place2, *column_result2);
auto& result2 = assert_cast<ColumnUInt8&>(assert_cast<ColumnArray&>(*column_result2).get_data())
.get_data();
for (int i = 0; i < result2.size(); i++) {
if (i == result2.size() - 1) {
EXPECT_EQ(result2[i], 0);
} else {
EXPECT_EQ(result2[i], 1);
}
}
EXPECT_EQ(column_result2->get_offsets()[-1], 0);
EXPECT_EQ(column_result2->get_offsets()[0], 3);
EXPECT_EQ(column_result2->get_offsets().size(), 1);
agg_function->destroy(place2);
agg_function->destroy(place3);
}
} // namespace doris::vectorized

View File

@ -0,0 +1,136 @@
---
{
"title": "RETENTION",
"language": "en"
}
---
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
## RETENTION
### Description
#### Syntax
`retention(event1, event2, ... , eventN);`
The `retention` function takes as arguments a set of conditions from 1 to 32 arguments of type `UInt8` that indicate whether a certain condition was met for the event. Any condition can be specified as an argument.
The conditions, except the first, apply in pairs: the result of the second will be true if the first and second are true, of the third if the first and third are true, etc.
#### Arguments
`event` — An expression that returns a `UInt8` result (1 or 0).
##### Returned value
The array of 1 or 0.
1 — Condition was met for the event.
0 — Condition wasn’t met for the event.
### example
```sql
DROP TABLE IF EXISTS retention_test;
CREATE TABLE retention_test(
`uid` int COMMENT 'user id',
`date` datetime COMMENT 'date time'
)
DUPLICATE KEY(uid)
DISTRIBUTED BY HASH(uid) BUCKETS 3
PROPERTIES (
"replication_num" = "1"
);
INSERT into retention_test (uid, date) values (0, '2022-10-12'),
(0, '2022-10-13'),
(0, '2022-10-14'),
(1, '2022-10-12'),
(1, '2022-10-13'),
(2, '2022-10-12');
SELECT * from retention_test;
+------+---------------------+
| uid | date |
+------+---------------------+
| 0 | 2022-10-14 00:00:00 |
| 0 | 2022-10-13 00:00:00 |
| 0 | 2022-10-12 00:00:00 |
| 1 | 2022-10-13 00:00:00 |
| 1 | 2022-10-12 00:00:00 |
| 2 | 2022-10-12 00:00:00 |
+------+---------------------+
SELECT
uid,
retention(date = '2022-10-12')
AS r
FROM retention_test
GROUP BY uid
ORDER BY uid ASC;
+------+------+
| uid | r |
+------+------+
| 0 | [1] |
| 1 | [1] |
| 2 | [1] |
+------+------+
SELECT
uid,
retention(date = '2022-10-12', date = '2022-10-13')
AS r
FROM retention_test
GROUP BY uid
ORDER BY uid ASC;
+------+--------+
| uid | r |
+------+--------+
| 0 | [1, 1] |
| 1 | [1, 1] |
| 2 | [1, 0] |
+------+--------+
SELECT
uid,
retention(date = '2022-10-12', date = '2022-10-13', date = '2022-10-14')
AS r
FROM retention_test
GROUP BY uid
ORDER BY uid ASC;
+------+-----------+
| uid | r |
+------+-----------+
| 0 | [1, 1, 1] |
| 1 | [1, 1, 0] |
| 2 | [1, 0, 0] |
+------+-----------+
```
### keywords
RETENTION

View File

@ -429,7 +429,8 @@
"sql-manual/sql-functions/aggregate-functions/any_value",
"sql-manual/sql-functions/aggregate-functions/var_samp",
"sql-manual/sql-functions/aggregate-functions/approx_count_distinct",
"sql-manual/sql-functions/aggregate-functions/variance"
"sql-manual/sql-functions/aggregate-functions/variance",
"sql-manual/sql-functions/aggregate-functions/retention"
]
},
{

View File

@ -0,0 +1,136 @@
---
{
"title": "RETENTION",
"language": "zh-CN"
}
---
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
## RETENTION
### description
#### Syntax
`retention(event1, event2, ... , eventN);`
留存函数将一组条件作为参数,类型为1到32个`UInt8`类型的参数,用来表示事件是否满足特定条件。 任何条件都可以指定为参数.
除了第一个以外,条件成对适用:如果第一个和第二个是真的,第二个结果将是真的,如果第一个和第三个是真的,第三个结果将是真的,等等。
#### Arguments
`event` — 返回`UInt8`结果(1或0)的表达式.
##### Returned value
由1和0组成的数组。
1 — 条件满足。
0 — 条件不满足
### example
```sql
DROP TABLE IF EXISTS retention_test;
CREATE TABLE retention_test(
`uid` int COMMENT 'user id',
`date` datetime COMMENT 'date time'
)
DUPLICATE KEY(uid)
DISTRIBUTED BY HASH(uid) BUCKETS 3
PROPERTIES (
"replication_num" = "1"
);
INSERT into retention_test (uid, date) values (0, '2022-10-12'),
(0, '2022-10-13'),
(0, '2022-10-14'),
(1, '2022-10-12'),
(1, '2022-10-13'),
(2, '2022-10-12');
SELECT * from retention_test;
+------+---------------------+
| uid | date |
+------+---------------------+
| 0 | 2022-10-14 00:00:00 |
| 0 | 2022-10-13 00:00:00 |
| 0 | 2022-10-12 00:00:00 |
| 1 | 2022-10-13 00:00:00 |
| 1 | 2022-10-12 00:00:00 |
| 2 | 2022-10-12 00:00:00 |
+------+---------------------+
SELECT
uid,
retention(date = '2022-10-12')
AS r
FROM retention_test
GROUP BY uid
ORDER BY uid ASC;
+------+------+
| uid | r |
+------+------+
| 0 | [1] |
| 1 | [1] |
| 2 | [1] |
+------+------+
SELECT
uid,
retention(date = '2022-10-12', date = '2022-10-13')
AS r
FROM retention_test
GROUP BY uid
ORDER BY uid ASC;
+------+--------+
| uid | r |
+------+--------+
| 0 | [1, 1] |
| 1 | [1, 1] |
| 2 | [1, 0] |
+------+--------+
SELECT
uid,
retention(date = '2022-10-12', date = '2022-10-13', date = '2022-10-14')
AS r
FROM retention_test
GROUP BY uid
ORDER BY uid ASC;
+------+-----------+
| uid | r |
+------+-----------+
| 0 | [1, 1, 1] |
| 1 | [1, 1, 0] |
| 2 | [1, 0, 0] |
+------+-----------+
```
### keywords
RETENTION

View File

@ -77,7 +77,7 @@ public class FunctionCallExpr extends Expr {
private static final ImmutableSet<String> DECIMAL_WIDER_TYPE_SET =
new ImmutableSortedSet.Builder(String.CASE_INSENSITIVE_ORDER)
.add("sum").add("avg").add("multi_distinct_sum").build();
private static final ImmutableSet<String> DECIMAL_FUNCTION_SET =
private static final ImmutableSet<String> DECIMAL_FUNCTION_SET =
new ImmutableSortedSet.Builder<>(String.CASE_INSENSITIVE_ORDER)
.addAll(DECIMAL_SAME_TYPE_SET)
.addAll(DECIMAL_WIDER_TYPE_SET)
@ -173,11 +173,11 @@ public class FunctionCallExpr extends Expr {
if (!orderByElements.isEmpty()) {
if (!VectorizedUtil.isVectorized()) {
throw new AnalysisException(
"ORDER BY for arguments only support in vec exec engine");
"ORDER BY for arguments only support in vec exec engine");
} else if (!AggregateFunction.SUPPORT_ORDER_BY_AGGREGATE_FUNCTION_NAME_SET.contains(
fnName.getFunction().toLowerCase())) {
throw new AnalysisException(
"ORDER BY not support for the function:" + fnName.getFunction().toLowerCase());
"ORDER BY not support for the function:" + fnName.getFunction().toLowerCase());
}
}
setChildren();
@ -845,6 +845,7 @@ public class FunctionCallExpr extends Expr {
/**
* This analyzeImp used for DefaultValueExprDef
* to generate a builtinFunction.
*
* @throws AnalysisException
*/
public void analyzeImplForDefaultValue(Type type) throws AnalysisException {
@ -907,7 +908,7 @@ public class FunctionCallExpr extends Expr {
if (!VectorizedUtil.isVectorized()) {
type = getChild(0).type.getMaxResolutionType();
}
fn = getBuiltinFunction(fnName.getFunction(), new Type[]{type},
fn = getBuiltinFunction(fnName.getFunction(), new Type[] {type},
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
} else if (fnName.getFunction().equalsIgnoreCase("count_distinct")) {
Type compatibleType = this.children.get(0).getType();
@ -920,7 +921,7 @@ public class FunctionCallExpr extends Expr {
}
}
fn = getBuiltinFunction(fnName.getFunction(), new Type[]{compatibleType},
fn = getBuiltinFunction(fnName.getFunction(), new Type[] {compatibleType},
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
} else if (fnName.getFunction().equalsIgnoreCase(FunctionSet.WINDOW_FUNNEL)) {
if (fnParams.exprs() == null || fnParams.exprs().size() < 4) {
@ -962,6 +963,21 @@ public class FunctionCallExpr extends Expr {
// cast date to datetime
uncheckedCastChild(ScalarType.DATETIMEV2, 2);
}
} else if (fnName.getFunction().equalsIgnoreCase(FunctionSet.RETENTION)) {
if (this.children.isEmpty()) {
throw new AnalysisException("The " + fnName + " function must have at least one param");
}
Type[] childTypes = new Type[children.size()];
for (int i = 0; i < children.size(); i++) {
if (children.get(i).type != Type.BOOLEAN) {
throw new AnalysisException("All params of "
+ fnName + " function must be boolean");
}
childTypes[i] = children.get(i).type;
}
fn = getBuiltinFunction(fnName.getFunction(), childTypes,
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
} else if (fnName.getFunction().equalsIgnoreCase("if")) {
Type[] childTypes = collectChildReturnTypes();
Type assignmentCompatibleType = ScalarType.getAssignmentCompatibleType(childTypes[1], childTypes[2], true);
@ -980,7 +996,7 @@ public class FunctionCallExpr extends Expr {
Type[] newChildTypes = new Type[children.size() - orderByElements.size()];
System.arraycopy(childTypes, 0, newChildTypes, 0, newChildTypes.length);
fn = getBuiltinFunction(fnName.getFunction(), newChildTypes,
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
} else {
// now first find table function in table function sets
if (isTableFnCall) {
@ -1457,7 +1473,7 @@ public class FunctionCallExpr extends Expr {
|| fnName.getFunction().equalsIgnoreCase("avg")
|| fnName.getFunction().equalsIgnoreCase("weekOfYear")) {
Type childType = getChild(0).type;
fn = getBuiltinFunction(fnName.getFunction(), new Type[]{childType},
fn = getBuiltinFunction(fnName.getFunction(), new Type[] {childType},
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
type = fn.getReturnType();
}

View File

@ -50,7 +50,8 @@ public class AggregateFunction extends Function {
"dense_rank", "multi_distinct_count", "multi_distinct_sum", "hll_union_agg", "hll_union", "bitmap_union",
"bitmap_intersect", "orthogonal_bitmap_intersect", "orthogonal_bitmap_intersect_count", "intersect_count",
"orthogonal_bitmap_union_count", FunctionSet.COUNT, "approx_count_distinct", "ndv",
FunctionSet.BITMAP_UNION_INT, FunctionSet.BITMAP_UNION_COUNT, "ndv_no_finalize", FunctionSet.WINDOW_FUNNEL);
FunctionSet.BITMAP_UNION_INT, FunctionSet.BITMAP_UNION_COUNT, "ndv_no_finalize", FunctionSet.WINDOW_FUNNEL,
FunctionSet.RETENTION);
public static ImmutableSet<String> ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET =
ImmutableSet.of("stddev_samp", "variance_samp", "var_samp", "percentile_approx");

View File

@ -1390,6 +1390,8 @@ public class FunctionSet<T> {
public static final String COUNT = "count";
public static final String WINDOW_FUNNEL = "window_funnel";
public static final String RETENTION = "retention";
// Populate all the aggregate builtins in the catalog.
// null symbols indicate the function does not need that step of the evaluation.
// An empty symbol indicates a TODO for the BE to implement the function.
@ -1466,6 +1468,21 @@ public class FunctionSet<T> {
"",
true, false, true, true));
// retention vectorization
addBuiltin(AggregateFunction.createBuiltin(FunctionSet.RETENTION,
Lists.newArrayList(Type.BOOLEAN),
Type.ARRAY,
Type.VARCHAR,
true,
"",
"",
"",
"",
"",
"",
"",
true, false, true, true));
for (Type t : Type.getSupportedTypes()) {
if (t.isNull()) {
continue; // NULL is handled through type promotion.

View File

@ -88,6 +88,78 @@ public class AggregateTest extends TestWithFeService {
} while (false);
}
@Test
public void testRetentionAnalysisException() throws Exception {
ConnectContext ctx = UtFrameUtils.createDefaultCtx();
// normal.
do {
String query = "select empid, retention(empid = 1, empid = 2) from "
+ DB_NAME + "." + TABLE_NAME + " group by empid";
try {
UtFrameUtils.parseAndAnalyzeStmt(query, ctx);
} catch (Exception e) {
Assert.fail("must be AnalysisException.");
}
} while (false);
do {
String query = "select empid, retention(empid = 1, empid = 2, empid = 3, empid = 4) from "
+ DB_NAME + "." + TABLE_NAME + " group by empid";
try {
UtFrameUtils.parseAndAnalyzeStmt(query, ctx);
} catch (Exception e) {
Assert.fail("must be AnalysisException.");
}
} while (false);
// less argument.
do {
String query = "select empid, retention() from "
+ DB_NAME + "." + TABLE_NAME + " group by empid";
try {
UtFrameUtils.parseAndAnalyzeStmt(query, ctx);
} catch (AnalysisException e) {
Assert.assertTrue(e.getMessage().contains("function must have at least one param"));
break;
} catch (Exception e) {
Assert.fail("must be AnalysisException.");
}
Assert.fail("must be AnalysisException.");
} while (false);
// argument with wrong type.
do {
String query = "select empid, retention('xx', empid = 1) from "
+ DB_NAME + "." + TABLE_NAME + " group by empid";
try {
UtFrameUtils.parseAndAnalyzeStmt(query, ctx);
} catch (AnalysisException e) {
Assert.assertTrue(e.getMessage().contains("All params of retention function must be boolean"));
break;
} catch (Exception e) {
Assert.fail("must be AnalysisException.");
}
Assert.fail("must be AnalysisException.");
} while (false);
do {
String query = "select empid, retention(1) from "
+ DB_NAME + "." + TABLE_NAME + " group by empid";
try {
UtFrameUtils.parseAndAnalyzeStmt(query, ctx);
} catch (AnalysisException e) {
Assert.assertTrue(e.getMessage().contains("All params of retention function must be boolean"));
break;
} catch (Exception e) {
Assert.fail("must be AnalysisException.");
}
Assert.fail("must be AnalysisException.");
} while (false);
}
@Test
public void testWindowFunnelAnalysisException() throws Exception {
ConnectContext ctx = UtFrameUtils.createDefaultCtx();

View File

@ -0,0 +1,59 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !test_aggregate_retention --
0
-- !test_aggregate_retention_2 --
0
-- !test_aggregate_retention_3 --
6
-- !test_aggregate_retention_4 --
0 2022-10-14T00:00
0 2022-10-13T00:00
0 2022-10-12T00:00
1 2022-10-13T00:00
1 2022-10-12T00:00
2 2022-10-12T00:00
-- !test_aggregate_retention_5 --
0 [1]
1 [1]
2 [1]
-- !test_aggregate_retention_6 --
0 [1, 1]
1 [1, 1]
2 [1, 0]
-- !test_aggregate_retention_7 --
0 [1, 1, 1]
1 [1, 1, 0]
2 [1, 0, 0]
-- !test_aggregate_retention_8 --
0
-- !test_aggregate_retention_9 --
0 2022-10-14T00:00
0 2022-10-13T00:00
0 2022-10-12T00:00
1 2022-10-13T00:00
1 2022-10-12T00:00
2 2022-10-12T00:00
-- !test_aggregate_retention_10 --
0 [1]
1 [1]
2 [1]
-- !test_aggregate_retention_11 --
0 [1, 1]
1 [1, 1]
2 [1, 0]
-- !test_aggregate_retention_12 --
0 [1, 1, 1]
1 [1, 1, 0]
2 [1, 0, 0]

View File

@ -0,0 +1,72 @@
DROP TABLE IF EXISTS retention_test;
CREATE TABLE retention_test(
`uid` int COMMENT 'user id',
`date` datetime COMMENT 'date time'
)
DUPLICATE KEY(uid)
DISTRIBUTED BY HASH(uid) BUCKETS 3
PROPERTIES (
"replication_num" = "1"
);
INSERT into retention_test (uid, date) values (0, '2022-10-12'),
(0, '2022-10-13'),
(0, '2022-10-14'),
(1, '2022-10-12'),
(1, '2022-10-13'),
(2, '2022-10-12');
SELECT * from retention_test ORDER BY uid;
SELECT
uid,
retention(date = '2022-10-12')
AS r
FROM retention_test
GROUP BY uid
ORDER BY uid ASC;
SELECT
uid,
retention(date = '2022-10-12', date = '2022-10-13')
AS r
FROM retention_test
GROUP BY uid
ORDER BY uid ASC;
SELECT
uid,
retention(date = '2022-10-12', date = '2022-10-13', date = '2022-10-14')
AS r
FROM retention_test
GROUP BY uid
ORDER BY uid ASC;
SET parallel_fragment_exec_instance_num=4;
SELECT * from retention_test ORDER BY uid;
SELECT
uid,
retention(date = '2022-10-12')
AS r
FROM retention_test
GROUP BY uid
ORDER BY uid ASC;
SELECT
uid,
retention(date = '2022-10-12', date = '2022-10-13')
AS r
FROM retention_test
GROUP BY uid
ORDER BY uid ASC;
SELECT
uid,
retention(date = '2022-10-12', date = '2022-10-13', date = '2022-10-14')
AS r
FROM retention_test
GROUP BY uid
ORDER BY uid ASC;