Make transient suppression optionally excludable via defines

This allows clients to exclude the transient suppression submodule from WebRTC builds, by defining WEBRTC_EXCLUDE_TRANSIENT_SUPPRESSOR.

The changes have been shown to be bitexact for a test dataset (when the flag is _not_ defined.)

No-Try: True
Bug: webrtc:11226, webrtc:11292
Change-Id: I6931c82a280a9b40a53ee1c2a9820ed9e674a9a5
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/171421
Commit-Queue: Sam Zackrisson <saza@webrtc.org>
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Reviewed-by: Per Åhgren <peah@webrtc.org>
Reviewed-by: Mirko Bonadei <mbonadei@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#30978}
This commit is contained in:
saza
2020-04-01 15:24:40 +02:00
committed by Commit Bot
parent fc23cc07e2
commit aa42ecde9a
13 changed files with 299 additions and 134 deletions

View File

@ -184,7 +184,8 @@ rtc_library("audio_processing") {
"agc2:fixed_digital",
"agc2:gain_applier",
"ns",
"transient:transient_suppressor",
"transient:transient_suppressor_api",
"transient:transient_suppressor_creator",
"vad",
"//third_party/abseil-cpp/absl/types:optional",
]

View File

@ -27,6 +27,7 @@
#include "modules/audio_processing/common.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "modules/audio_processing/transient/transient_suppressor_creator.h"
#include "rtc_base/atomic_ops.h"
#include "rtc_base/checks.h"
#include "rtc_base/constructor_magic.h"
@ -1635,12 +1636,18 @@ bool AudioProcessingImpl::UpdateActiveSubmoduleStates() {
void AudioProcessingImpl::InitializeTransientSuppressor() {
if (config_.transient_suppression.enabled) {
// Attempt to create a transient suppressor, if one is not already created.
if (!submodules_.transient_suppressor) {
submodules_.transient_suppressor.reset(new TransientSuppressor());
submodules_.transient_suppressor = CreateTransientSuppressor();
}
if (submodules_.transient_suppressor) {
submodules_.transient_suppressor->Initialize(
proc_fullband_sample_rate_hz(), capture_nonlocked_.split_rate,
num_proc_channels());
} else {
RTC_LOG(LS_WARNING)
<< "No transient suppressor created (probably disabled)";
}
submodules_.transient_suppressor->Initialize(proc_fullband_sample_rate_hz(),
capture_nonlocked_.split_rate,
num_proc_channels());
} else {
submodules_.transient_suppressor.reset();
}
@ -1843,28 +1850,28 @@ void AudioProcessingImpl::InitializeNoiseSuppressor() {
submodules_.noise_suppressor.reset();
if (config_.noise_suppression.enabled) {
auto map_level =
[](AudioProcessing::Config::NoiseSuppression::Level level) {
using NoiseSuppresionConfig =
AudioProcessing::Config::NoiseSuppression;
switch (level) {
case NoiseSuppresionConfig::kLow:
return NsConfig::SuppressionLevel::k6dB;
case NoiseSuppresionConfig::kModerate:
return NsConfig::SuppressionLevel::k12dB;
case NoiseSuppresionConfig::kHigh:
return NsConfig::SuppressionLevel::k18dB;
case NoiseSuppresionConfig::kVeryHigh:
return NsConfig::SuppressionLevel::k21dB;
default:
RTC_NOTREACHED();
}
};
auto map_level =
[](AudioProcessing::Config::NoiseSuppression::Level level) {
using NoiseSuppresionConfig =
AudioProcessing::Config::NoiseSuppression;
switch (level) {
case NoiseSuppresionConfig::kLow:
return NsConfig::SuppressionLevel::k6dB;
case NoiseSuppresionConfig::kModerate:
return NsConfig::SuppressionLevel::k12dB;
case NoiseSuppresionConfig::kHigh:
return NsConfig::SuppressionLevel::k18dB;
case NoiseSuppresionConfig::kVeryHigh:
return NsConfig::SuppressionLevel::k21dB;
default:
RTC_NOTREACHED();
}
};
NsConfig cfg;
cfg.target_level = map_level(config_.noise_suppression.level);
submodules_.noise_suppressor = std::make_unique<NoiseSuppressor>(
cfg, proc_sample_rate_hz(), num_proc_channels());
NsConfig cfg;
cfg.target_level = map_level(config_.noise_suppression.level);
submodules_.noise_suppressor = std::make_unique<NoiseSuppressor>(
cfg, proc_sample_rate_hz(), num_proc_channels());
}
}

View File

@ -8,7 +8,28 @@
import("../../../webrtc.gni")
rtc_library("transient_suppressor") {
rtc_source_set("transient_suppressor_api") {
sources = [ "transient_suppressor.h" ]
}
rtc_library("transient_suppressor_creator") {
sources = [
"transient_suppressor_creator.cc",
"transient_suppressor_creator.h",
]
deps = [
":transient_suppressor_api",
":transient_suppressor_impl",
]
}
rtc_library("transient_suppressor_impl") {
visibility = [
":transient_suppressor_creator",
":transient_suppression_test",
":transient_suppression_unittests",
":click_annotate",
]
sources = [
"common.h",
"daubechies_8_wavelet_coeffs.h",
@ -17,8 +38,8 @@ rtc_library("transient_suppressor") {
"moving_moments.h",
"transient_detector.cc",
"transient_detector.h",
"transient_suppressor.cc",
"transient_suppressor.h",
"transient_suppressor_impl.cc",
"transient_suppressor_impl.h",
"windows_private.h",
"wpd_node.cc",
"wpd_node.h",
@ -26,6 +47,7 @@ rtc_library("transient_suppressor") {
"wpd_tree.h",
]
deps = [
":transient_suppressor_api",
"../../../common_audio:common_audio",
"../../../common_audio:common_audio_c",
"../../../common_audio:fir_filter",
@ -46,7 +68,7 @@ if (rtc_include_tests) {
"file_utils.h",
]
deps = [
":transient_suppressor",
":transient_suppressor_impl",
"..:audio_processing",
"../../../rtc_base/system:file_wrapper",
"../../../system_wrappers",
@ -61,7 +83,7 @@ if (rtc_include_tests) {
"transient_suppression_test.cc",
]
deps = [
":transient_suppressor",
":transient_suppressor_impl",
"..:audio_processing",
"../../../common_audio",
"../../../rtc_base:rtc_base_approved",
@ -90,7 +112,7 @@ if (rtc_include_tests) {
"wpd_tree_unittest.cc",
]
deps = [
":transient_suppressor",
":transient_suppressor_impl",
"../../../rtc_base:stringutils",
"../../../rtc_base/system:file_wrapper",
"../../../test:fileutils",

View File

@ -20,7 +20,7 @@
#include "absl/flags/parse.h"
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc/agc.h"
#include "modules/audio_processing/transient/transient_suppressor.h"
#include "modules/audio_processing/transient/transient_suppressor_impl.h"
#include "test/gtest.h"
#include "test/testsupport/file_utils.h"
@ -165,7 +165,7 @@ void void_main() {
Agc agc;
TransientSuppressor suppressor;
TransientSuppressorImpl suppressor;
suppressor.Initialize(absl::GetFlag(FLAGS_sample_rate_hz), detection_rate_hz,
absl::GetFlag(FLAGS_num_channels));

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
@ -13,23 +13,19 @@
#include <stddef.h>
#include <stdint.h>
#include <memory>
#include "rtc_base/gtest_prod_util.h"
namespace webrtc {
class TransientDetector;
// Detects transients in an audio stream and suppress them using a simple
// restoration algorithm that attenuates unexpected spikes in the spectrum.
class TransientSuppressor {
public:
TransientSuppressor();
~TransientSuppressor();
virtual ~TransientSuppressor() {}
int Initialize(int sample_rate_hz, int detector_rate_hz, int num_channels);
virtual int Initialize(int sample_rate_hz,
int detector_rate_hz,
int num_channels) = 0;
// Processes a |data| chunk, and returns it with keystrokes suppressed from
// it. The float format is assumed to be int16 ranged. If there are more than
@ -48,71 +44,15 @@ class TransientSuppressor {
// always be set to 1.
// |key_pressed| determines if a key was pressed on this audio chunk.
// Returns 0 on success and -1 otherwise.
int Suppress(float* data,
size_t data_length,
int num_channels,
const float* detection_data,
size_t detection_length,
const float* reference_data,
size_t reference_length,
float voice_probability,
bool key_pressed);
private:
FRIEND_TEST_ALL_PREFIXES(TransientSuppressorTest,
TypingDetectionLogicWorksAsExpectedForMono);
void Suppress(float* in_ptr, float* spectral_mean, float* out_ptr);
void UpdateKeypress(bool key_pressed);
void UpdateRestoration(float voice_probability);
void UpdateBuffers(float* data);
void HardRestoration(float* spectral_mean);
void SoftRestoration(float* spectral_mean);
std::unique_ptr<TransientDetector> detector_;
size_t data_length_;
size_t detection_length_;
size_t analysis_length_;
size_t buffer_delay_;
size_t complex_analysis_length_;
int num_channels_;
// Input buffer where the original samples are stored.
std::unique_ptr<float[]> in_buffer_;
std::unique_ptr<float[]> detection_buffer_;
// Output buffer where the restored samples are stored.
std::unique_ptr<float[]> out_buffer_;
// Arrays for fft.
std::unique_ptr<size_t[]> ip_;
std::unique_ptr<float[]> wfft_;
std::unique_ptr<float[]> spectral_mean_;
// Stores the data for the fft.
std::unique_ptr<float[]> fft_buffer_;
std::unique_ptr<float[]> magnitudes_;
const float* window_;
std::unique_ptr<float[]> mean_factor_;
float detector_smoothed_;
int keypress_counter_;
int chunks_since_keypress_;
bool detection_enabled_;
bool suppression_enabled_;
bool use_hard_restoration_;
int chunks_since_voice_change_;
uint32_t seed_;
bool using_reference_;
virtual int Suppress(float* data,
size_t data_length,
int num_channels,
const float* detection_data,
size_t detection_length,
const float* reference_data,
size_t reference_length,
float voice_probability,
bool key_pressed) = 0;
};
} // namespace webrtc

View File

@ -0,0 +1,27 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/transient/transient_suppressor_creator.h"
#include <memory>
#include "modules/audio_processing/transient/transient_suppressor_impl.h"
namespace webrtc {
std::unique_ptr<TransientSuppressor> CreateTransientSuppressor() {
#ifdef WEBRTC_EXCLUDE_TRANSIENT_SUPPRESSOR
return nullptr;
#else
return std::make_unique<TransientSuppressorImpl>();
#endif
}
} // namespace webrtc

View File

@ -0,0 +1,26 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_TRANSIENT_TRANSIENT_SUPPRESSOR_CREATOR_H_
#define MODULES_AUDIO_PROCESSING_TRANSIENT_TRANSIENT_SUPPRESSOR_CREATOR_H_
#include <memory>
#include "modules/audio_processing/transient/transient_suppressor.h"
namespace webrtc {
// Creates a transient suppressor.
// Will return nullptr if WEBRTC_EXCLUDE_TRANSIENT_SUPPRESSOR is defined.
std::unique_ptr<TransientSuppressor> CreateTransientSuppressor();
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_TRANSIENT_TRANSIENT_SUPPRESSOR_CREATOR_H_

View File

@ -8,13 +8,15 @@
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/transient/transient_suppressor.h"
#include "modules/audio_processing/transient/transient_suppressor_impl.h"
#include <string.h>
#include <algorithm>
#include <cmath>
#include <complex>
#include <deque>
#include <limits>
#include <set>
#include "common_audio/include/audio_util.h"
@ -22,6 +24,7 @@
#include "common_audio/third_party/fft4g/fft4g.h"
#include "modules/audio_processing/transient/common.h"
#include "modules/audio_processing/transient/transient_detector.h"
#include "modules/audio_processing/transient/transient_suppressor.h"
#include "modules/audio_processing/transient/windows_private.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
@ -43,7 +46,7 @@ float ComplexMagnitude(float a, float b) {
} // namespace
TransientSuppressor::TransientSuppressor()
TransientSuppressorImpl::TransientSuppressorImpl()
: data_length_(0),
detection_length_(0),
analysis_length_(0),
@ -61,11 +64,11 @@ TransientSuppressor::TransientSuppressor()
seed_(182),
using_reference_(false) {}
TransientSuppressor::~TransientSuppressor() {}
TransientSuppressorImpl::~TransientSuppressorImpl() {}
int TransientSuppressor::Initialize(int sample_rate_hz,
int detection_rate_hz,
int num_channels) {
int TransientSuppressorImpl::Initialize(int sample_rate_hz,
int detection_rate_hz,
int num_channels) {
switch (sample_rate_hz) {
case ts::kSampleRate8kHz:
analysis_length_ = 128u;
@ -155,15 +158,15 @@ int TransientSuppressor::Initialize(int sample_rate_hz,
return 0;
}
int TransientSuppressor::Suppress(float* data,
size_t data_length,
int num_channels,
const float* detection_data,
size_t detection_length,
const float* reference_data,
size_t reference_length,
float voice_probability,
bool key_pressed) {
int TransientSuppressorImpl::Suppress(float* data,
size_t data_length,
int num_channels,
const float* detection_data,
size_t detection_length,
const float* reference_data,
size_t reference_length,
float voice_probability,
bool key_pressed) {
if (!data || data_length != data_length_ || num_channels != num_channels_ ||
detection_length != detection_length_ || voice_probability < 0 ||
voice_probability > 1) {
@ -222,9 +225,9 @@ int TransientSuppressor::Suppress(float* data,
// This should only be called when detection is enabled. UpdateBuffers() must
// have been called. At return, |out_buffer_| will be filled with the
// processed output.
void TransientSuppressor::Suppress(float* in_ptr,
float* spectral_mean,
float* out_ptr) {
void TransientSuppressorImpl::Suppress(float* in_ptr,
float* spectral_mean,
float* out_ptr) {
// Go to frequency domain.
for (size_t i = 0; i < analysis_length_; ++i) {
// TODO(aluebs): Rename windows
@ -270,7 +273,7 @@ void TransientSuppressor::Suppress(float* in_ptr,
}
}
void TransientSuppressor::UpdateKeypress(bool key_pressed) {
void TransientSuppressorImpl::UpdateKeypress(bool key_pressed) {
const int kKeypressPenalty = 1000 / ts::kChunkSizeMs;
const int kIsTypingThreshold = 1000 / ts::kChunkSizeMs;
const int kChunksUntilNotTyping = 4000 / ts::kChunkSizeMs; // 4 seconds.
@ -300,7 +303,7 @@ void TransientSuppressor::UpdateKeypress(bool key_pressed) {
}
}
void TransientSuppressor::UpdateRestoration(float voice_probability) {
void TransientSuppressorImpl::UpdateRestoration(float voice_probability) {
const int kHardRestorationOffsetDelay = 3;
const int kHardRestorationOnsetDelay = 80;
@ -323,7 +326,7 @@ void TransientSuppressor::UpdateRestoration(float voice_probability) {
// Shift buffers to make way for new data. Must be called after
// |detection_enabled_| is updated by UpdateKeypress().
void TransientSuppressor::UpdateBuffers(float* data) {
void TransientSuppressorImpl::UpdateBuffers(float* data) {
// TODO(aluebs): Change to ring buffer.
memmove(in_buffer_.get(), &in_buffer_[data_length_],
(buffer_delay_ + (num_channels_ - 1) * analysis_length_) *
@ -350,7 +353,7 @@ void TransientSuppressor::UpdateBuffers(float* data) {
// Attenuates by a certain factor every peak in the |fft_buffer_| that exceeds
// the spectral mean. The attenuation depends on |detector_smoothed_|.
// If a restoration takes place, the |magnitudes_| are updated to the new value.
void TransientSuppressor::HardRestoration(float* spectral_mean) {
void TransientSuppressorImpl::HardRestoration(float* spectral_mean) {
const float detector_result =
1.f - std::pow(1.f - detector_smoothed_, using_reference_ ? 200.f : 50.f);
// To restore, we get the peaks in the spectrum. If higher than the previous
@ -377,7 +380,7 @@ void TransientSuppressor::HardRestoration(float* spectral_mean) {
// the spectral mean and that is lower than some function of the current block
// frequency mean. The attenuation depends on |detector_smoothed_|.
// If a restoration takes place, the |magnitudes_| are updated to the new value.
void TransientSuppressor::SoftRestoration(float* spectral_mean) {
void TransientSuppressorImpl::SoftRestoration(float* spectral_mean) {
// Get the spectral magnitude mean of the current block.
float block_frequency_mean = 0;
for (size_t i = kMinVoiceBin; i < kMaxVoiceBin; ++i) {

View File

@ -0,0 +1,123 @@
/*
* Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_TRANSIENT_TRANSIENT_SUPPRESSOR_IMPL_H_
#define MODULES_AUDIO_PROCESSING_TRANSIENT_TRANSIENT_SUPPRESSOR_IMPL_H_
#include <stddef.h>
#include <stdint.h>
#include <memory>
#include "modules/audio_processing/transient/transient_suppressor.h"
#include "rtc_base/gtest_prod_util.h"
namespace webrtc {
class TransientDetector;
// Detects transients in an audio stream and suppress them using a simple
// restoration algorithm that attenuates unexpected spikes in the spectrum.
class TransientSuppressorImpl : public TransientSuppressor {
public:
TransientSuppressorImpl();
~TransientSuppressorImpl() override;
int Initialize(int sample_rate_hz,
int detector_rate_hz,
int num_channels) override;
// Processes a |data| chunk, and returns it with keystrokes suppressed from
// it. The float format is assumed to be int16 ranged. If there are more than
// one channel, the chunks are concatenated one after the other in |data|.
// |data_length| must be equal to |data_length_|.
// |num_channels| must be equal to |num_channels_|.
// A sub-band, ideally the higher, can be used as |detection_data|. If it is
// NULL, |data| is used for the detection too. The |detection_data| is always
// assumed mono.
// If a reference signal (e.g. keyboard microphone) is available, it can be
// passed in as |reference_data|. It is assumed mono and must have the same
// length as |data|. NULL is accepted if unavailable.
// This suppressor performs better if voice information is available.
// |voice_probability| is the probability of voice being present in this chunk
// of audio. If voice information is not available, |voice_probability| must
// always be set to 1.
// |key_pressed| determines if a key was pressed on this audio chunk.
// Returns 0 on success and -1 otherwise.
int Suppress(float* data,
size_t data_length,
int num_channels,
const float* detection_data,
size_t detection_length,
const float* reference_data,
size_t reference_length,
float voice_probability,
bool key_pressed) override;
private:
FRIEND_TEST_ALL_PREFIXES(TransientSuppressorImplTest,
TypingDetectionLogicWorksAsExpectedForMono);
void Suppress(float* in_ptr, float* spectral_mean, float* out_ptr);
void UpdateKeypress(bool key_pressed);
void UpdateRestoration(float voice_probability);
void UpdateBuffers(float* data);
void HardRestoration(float* spectral_mean);
void SoftRestoration(float* spectral_mean);
std::unique_ptr<TransientDetector> detector_;
size_t data_length_;
size_t detection_length_;
size_t analysis_length_;
size_t buffer_delay_;
size_t complex_analysis_length_;
int num_channels_;
// Input buffer where the original samples are stored.
std::unique_ptr<float[]> in_buffer_;
std::unique_ptr<float[]> detection_buffer_;
// Output buffer where the restored samples are stored.
std::unique_ptr<float[]> out_buffer_;
// Arrays for fft.
std::unique_ptr<size_t[]> ip_;
std::unique_ptr<float[]> wfft_;
std::unique_ptr<float[]> spectral_mean_;
// Stores the data for the fft.
std::unique_ptr<float[]> fft_buffer_;
std::unique_ptr<float[]> magnitudes_;
const float* window_;
std::unique_ptr<float[]> mean_factor_;
float detector_smoothed_;
int keypress_counter_;
int chunks_since_keypress_;
bool detection_enabled_;
bool suppression_enabled_;
bool use_hard_restoration_;
int chunks_since_voice_change_;
uint32_t seed_;
bool using_reference_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_TRANSIENT_TRANSIENT_SUPPRESSOR_IMPL_H_

View File

@ -8,17 +8,17 @@
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/transient/transient_suppressor.h"
#include "modules/audio_processing/transient/transient_suppressor_impl.h"
#include "modules/audio_processing/transient/common.h"
#include "test/gtest.h"
namespace webrtc {
TEST(TransientSuppressorTest, TypingDetectionLogicWorksAsExpectedForMono) {
TEST(TransientSuppressorImplTest, TypingDetectionLogicWorksAsExpectedForMono) {
static const int kNumChannels = 1;
TransientSuppressor ts;
TransientSuppressorImpl ts;
ts.Initialize(ts::kSampleRate16kHz, ts::kSampleRate16kHz, kNumChannels);
// Each key-press enables detection.