Files
doris/be/src/vec/aggregate_functions/aggregate_function_bitmap.h
starocean999 ca0906626f [BUG] fix bitmap function bug (#10502)
* fix bitmap function bug

* add regression test
2022-07-01 15:30:16 +08:00

209 lines
7.7 KiB
C++

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column_complex.h"
#include "vec/columns/column_nullable.h"
#include "vec/common/assert_cast.h"
#include "vec/data_types/data_type_bitmap.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_number.h"
#include "vec/io/io_helper.h"
namespace doris::vectorized {
// Policy type for the "bitmap_union" aggregate. Inputs are folded into the
// running result either element-by-element (`res.add`) or bitmap-by-bitmap
// (`res |= data`).
struct AggregateFunctionBitmapUnionOp {
    static constexpr auto name = "bitmap_union";

    // Element overload: inserts one value into the running bitmap.
    // The first-value flag is neither read nor written on this path.
    template <typename T>
    static void add(BitmapValue& res, const T& data, bool& /*is_first*/) {
        res.add(data);
    }

    // Bitmap overload: the very first input is copied into `res` (and the flag
    // is cleared); every later input is combined into `res` with `|=`.
    static void add(BitmapValue& res, const BitmapValue& data, bool& is_first) {
        if (UNLIKELY(is_first)) {
            res = data;
            is_first = false;
            return;
        }
        res |= data;
    }

    // Merging a partial result follows exactly the same rule as adding a
    // bitmap input, so it forwards to the bitmap overload of add().
    static void merge(BitmapValue& res, const BitmapValue& data, bool& is_first) {
        add(res, data, is_first);
    }
};
// Policy type for the "bitmap_intersect" aggregate. The first bitmap seen
// seeds the result; each subsequent bitmap is combined into it with `&=`.
struct AggregateFunctionBitmapIntersectOp {
    static constexpr auto name = "bitmap_intersect";

    // Folds one bitmap input into `res`. While `is_first` is still set the
    // input is copied verbatim and the flag is cleared.
    static void add(BitmapValue& res, const BitmapValue& data, bool& is_first) {
        if (UNLIKELY(is_first)) {
            res = data;
            is_first = false;
            return;
        }
        res &= data;
    }

    // A partial result is merged with the same rule used for a single input.
    static void merge(BitmapValue& res, const BitmapValue& data, bool& is_first) {
        add(res, data, is_first);
    }
};
template <typename Op>
struct AggregateFunctionBitmapData {
BitmapValue value;
bool is_first = true;
template <typename T>
void add(const T& data) {
Op::add(value, data, is_first);
}
void merge(const BitmapValue& data) { Op::merge(value, data, is_first); }
void write(BufferWritable& buf) const { DataTypeBitMap::serialize_as_stream(value, buf); }
void read(BufferReadable& buf) { DataTypeBitMap::deserialize_as_stream(value, buf); }
void reset() { is_first = true; }
BitmapValue& get() { return value; }
};
// Aggregate function that reduces a ColumnBitmap argument to a single bitmap,
// delegating the combination rule to `Op` (which also supplies the function
// name through `Op::name`).
template <typename Op>
class AggregateFunctionBitmapOp final
        : public IAggregateFunctionDataHelper<AggregateFunctionBitmapData<Op>,
                                              AggregateFunctionBitmapOp<Op>> {
    // Short names for the per-group state and for the helper base class.
    using State = AggregateFunctionBitmapData<Op>;
    using Helper = IAggregateFunctionDataHelper<State, AggregateFunctionBitmapOp<Op>>;

public:
    using ResultDataType = BitmapValue;
    using ColVecType = ColumnBitmap;
    using ColVecResult = ColumnBitmap;

    AggregateFunctionBitmapOp(const DataTypes& argument_types_)
            : Helper(argument_types_, {}) {}

    String get_name() const override { return Op::name; }

    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeBitMap>(); }

    // Feeds the bitmap stored at `row_num` of the first argument column into
    // the state located at `place`.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
             Arena*) const override {
        const auto& bitmaps = static_cast<const ColVecType&>(*columns[0]).get_data();
        this->data(place).add(bitmaps[row_num]);
    }

    // Merges the bitmap held by the state at `rhs` into the state at `place`.
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena*) const override {
        // get() is non-const, hence the cast; the source state is only read.
        auto& other = const_cast<State&>(this->data(rhs));
        this->data(place).merge(other.get());
    }

    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        this->data(place).write(buf);
    }

    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena*) const override {
        this->data(place).read(buf);
    }

    // Appends the accumulated bitmap to the result column `to`.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        // get() is non-const, hence the cast; the state is only read.
        auto& state = const_cast<State&>(this->data(place));
        auto& result_column = static_cast<ColVecResult&>(to);
        result_column.get_data().push_back(state.get());
    }

    void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); }
};
// Aggregate function that accumulates its argument with the union policy and
// reports the cardinality of the resulting bitmap as an Int64.
//
// nullable:   when true the argument column is a ColumnNullable and NULL rows
//             are skipped.
// ColVecType: concrete (non-nullable) column type holding the argument values.
template <bool nullable, typename ColVecType>
class AggregateFunctionBitmapCount final
        : public IAggregateFunctionDataHelper<
                  AggregateFunctionBitmapData<AggregateFunctionBitmapUnionOp>,
                  AggregateFunctionBitmapCount<nullable, ColVecType>> {
public:
    using ColVecResult = ColumnVector<Int64>;
    using AggFunctionData = AggregateFunctionBitmapData<AggregateFunctionBitmapUnionOp>;

    AggregateFunctionBitmapCount(const DataTypes& argument_types_)
            : IAggregateFunctionDataHelper<AggFunctionData,
                                           AggregateFunctionBitmapCount<nullable, ColVecType>>(
                      argument_types_, {}) {}

    String get_name() const override { return "count"; }

    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); }

    // Feeds the value at `row_num` of the first argument column into the state
    // at `place`. In the nullable instantiation a NULL row contributes nothing.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
             Arena*) const override {
        if constexpr (!nullable) {
            const auto& values = static_cast<const ColVecType&>(*columns[0]).get_data();
            this->data(place).add(values[row_num]);
        } else {
            const auto& outer = assert_cast<const ColumnNullable&>(*columns[0]);
            if (outer.is_null_at(row_num)) {
                return;
            }
            const auto& values =
                    static_cast<const ColVecType&>(outer.get_nested_column()).get_data();
            this->data(place).add(values[row_num]);
        }
    }

    // Merges the bitmap held by the state at `rhs` into the state at `place`.
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena*) const override {
        // get() is non-const, hence the cast; the source state is only read.
        auto& other = const_cast<AggFunctionData&>(this->data(rhs));
        this->data(place).merge(other.get());
    }

    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        this->data(place).write(buf);
    }

    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena*) const override {
        this->data(place).read(buf);
    }

    // Appends the cardinality of the accumulated bitmap to the result column.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& bitmap = const_cast<AggFunctionData&>(this->data(place)).get();
        auto& result_column = static_cast<ColVecResult&>(to);
        result_column.get_data().push_back(bitmap.cardinality());
    }

    void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); }
};
// Factory returning an AggregateFunctionPtr for the given function name,
// argument types, parameters and result nullability. Declaration only; no
// definition is visible in this header.
// NOTE(review): judging by the name this presumably builds the "bitmap_union"
// aggregate — confirm against the definition.
AggregateFunctionPtr create_aggregate_function_bitmap_union(const std::string& name,
const DataTypes& argument_types,
const Array& parameters,
const bool result_is_nullable);
} // namespace doris::vectorized