APM Transient Suppressor (TS): integrate VoiceProbabilityDelayUnit
This CL adds a component in the TS implementation to return a delayed version of the voice probability values observed when `Suppress()` is called. That is needed in order to temporally align the voice probability values to the processed audio since TS adds algorithmic delay. Bug: webrtc:13663 Change-Id: I5041ace3939d2ce7ba084ae703428e66f1aa06be Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/255860 Reviewed-by: Hanna Silen <silen@webrtc.org> Commit-Queue: Alessio Bazzica <alessiob@webrtc.org> Cr-Commit-Position: refs/heads/main@{#36496}
This commit is contained in:

committed by
WebRTC LUCI CQ

parent
26b23b8fcc
commit
7efe5332f2
@ -37,6 +37,7 @@ rtc_library("transient_suppressor_impl") {
|
|||||||
]
|
]
|
||||||
deps = [
|
deps = [
|
||||||
":transient_suppressor_api",
|
":transient_suppressor_api",
|
||||||
|
":voice_probability_delay_unit",
|
||||||
"../../../common_audio:common_audio",
|
"../../../common_audio:common_audio",
|
||||||
"../../../common_audio:common_audio_c",
|
"../../../common_audio:common_audio_c",
|
||||||
"../../../common_audio:fir_filter",
|
"../../../common_audio:fir_filter",
|
||||||
@ -96,6 +97,7 @@ if (rtc_include_tests) {
|
|||||||
"//testing/gtest",
|
"//testing/gtest",
|
||||||
"//third_party/abseil-cpp/absl/flags:flag",
|
"//third_party/abseil-cpp/absl/flags:flag",
|
||||||
"//third_party/abseil-cpp/absl/flags:parse",
|
"//third_party/abseil-cpp/absl/flags:parse",
|
||||||
|
"//third_party/abseil-cpp/absl/types:optional",
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -124,5 +126,6 @@ if (rtc_include_tests) {
|
|||||||
"../../../test:test_support",
|
"../../../test:test_support",
|
||||||
"//testing/gtest",
|
"//testing/gtest",
|
||||||
]
|
]
|
||||||
|
absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -56,15 +56,18 @@ class TransientSuppressor {
|
|||||||
// of audio. If voice information is not available, `voice_probability` must
|
// of audio. If voice information is not available, `voice_probability` must
|
||||||
// always be set to 1.
|
// always be set to 1.
|
||||||
// `key_pressed` determines if a key was pressed on this audio chunk.
|
// `key_pressed` determines if a key was pressed on this audio chunk.
|
||||||
virtual void Suppress(float* data,
|
// Returns a delayed version of `voice_probability` according to the
|
||||||
size_t data_length,
|
// algorithmic delay introduced by this method. In this way, the modified
|
||||||
int num_channels,
|
// `data` and the returned voice probability will be temporally aligned.
|
||||||
const float* detection_data,
|
virtual float Suppress(float* data,
|
||||||
size_t detection_length,
|
size_t data_length,
|
||||||
const float* reference_data,
|
int num_channels,
|
||||||
size_t reference_length,
|
const float* detection_data,
|
||||||
float voice_probability,
|
size_t detection_length,
|
||||||
bool key_pressed) = 0;
|
const float* reference_data,
|
||||||
|
size_t reference_length,
|
||||||
|
float voice_probability,
|
||||||
|
bool key_pressed) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace webrtc
|
} // namespace webrtc
|
||||||
|
@ -62,6 +62,7 @@ TransientSuppressorImpl::TransientSuppressorImpl(VadMode vad_mode,
|
|||||||
int detector_rate_hz,
|
int detector_rate_hz,
|
||||||
int num_channels)
|
int num_channels)
|
||||||
: vad_mode_(vad_mode),
|
: vad_mode_(vad_mode),
|
||||||
|
voice_probability_delay_unit_(/*delay_num_samples=*/0, sample_rate_hz),
|
||||||
analyzed_audio_is_silent_(false),
|
analyzed_audio_is_silent_(false),
|
||||||
data_length_(0),
|
data_length_(0),
|
||||||
detection_length_(0),
|
detection_length_(0),
|
||||||
@ -125,6 +126,9 @@ void TransientSuppressorImpl::Initialize(int sample_rate_hz,
|
|||||||
RTC_DCHECK_LE(data_length_, analysis_length_);
|
RTC_DCHECK_LE(data_length_, analysis_length_);
|
||||||
buffer_delay_ = analysis_length_ - data_length_;
|
buffer_delay_ = analysis_length_ - data_length_;
|
||||||
|
|
||||||
|
voice_probability_delay_unit_.Initialize(/*delay_num_samples=*/buffer_delay_,
|
||||||
|
sample_rate_hz);
|
||||||
|
|
||||||
complex_analysis_length_ = analysis_length_ / 2 + 1;
|
complex_analysis_length_ = analysis_length_ / 2 + 1;
|
||||||
RTC_DCHECK_GE(complex_analysis_length_, kMaxVoiceBin);
|
RTC_DCHECK_GE(complex_analysis_length_, kMaxVoiceBin);
|
||||||
num_channels_ = num_channels;
|
num_channels_ = num_channels;
|
||||||
@ -175,19 +179,21 @@ void TransientSuppressorImpl::Initialize(int sample_rate_hz,
|
|||||||
using_reference_ = false;
|
using_reference_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void TransientSuppressorImpl::Suppress(float* data,
|
float TransientSuppressorImpl::Suppress(float* data,
|
||||||
size_t data_length,
|
size_t data_length,
|
||||||
int num_channels,
|
int num_channels,
|
||||||
const float* detection_data,
|
const float* detection_data,
|
||||||
size_t detection_length,
|
size_t detection_length,
|
||||||
const float* reference_data,
|
const float* reference_data,
|
||||||
size_t reference_length,
|
size_t reference_length,
|
||||||
float voice_probability,
|
float voice_probability,
|
||||||
bool key_pressed) {
|
bool key_pressed) {
|
||||||
if (!data || data_length != data_length_ || num_channels != num_channels_ ||
|
if (!data || data_length != data_length_ || num_channels != num_channels_ ||
|
||||||
detection_length != detection_length_ || voice_probability < 0 ||
|
detection_length != detection_length_ || voice_probability < 0 ||
|
||||||
voice_probability > 1) {
|
voice_probability > 1) {
|
||||||
return;
|
// The audio is not modified, so the voice probability is returned as is
|
||||||
|
// (delay not applied).
|
||||||
|
return voice_probability;
|
||||||
}
|
}
|
||||||
|
|
||||||
UpdateKeypress(key_pressed);
|
UpdateKeypress(key_pressed);
|
||||||
@ -205,7 +211,9 @@ void TransientSuppressorImpl::Suppress(float* data,
|
|||||||
float detector_result = detector_->Detect(detection_data, detection_length,
|
float detector_result = detector_->Detect(detection_data, detection_length,
|
||||||
reference_data, reference_length);
|
reference_data, reference_length);
|
||||||
if (detector_result < 0) {
|
if (detector_result < 0) {
|
||||||
return;
|
// The audio is not modified, so the voice probability is returned as is
|
||||||
|
// (delay not applied).
|
||||||
|
return voice_probability;
|
||||||
}
|
}
|
||||||
|
|
||||||
using_reference_ = detector_->using_reference();
|
using_reference_ = detector_->using_reference();
|
||||||
@ -235,6 +243,9 @@ void TransientSuppressorImpl::Suppress(float* data,
|
|||||||
: &in_buffer_[i * analysis_length_],
|
: &in_buffer_[i * analysis_length_],
|
||||||
data_length_ * sizeof(*data));
|
data_length_ * sizeof(*data));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The audio has been modified, return the delayed voice probability.
|
||||||
|
return voice_probability_delay_unit_.Delay(voice_probability);
|
||||||
}
|
}
|
||||||
|
|
||||||
// This should only be called when detection is enabled. UpdateBuffers() must
|
// This should only be called when detection is enabled. UpdateBuffers() must
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "modules/audio_processing/transient/transient_suppressor.h"
|
#include "modules/audio_processing/transient/transient_suppressor.h"
|
||||||
|
#include "modules/audio_processing/transient/voice_probability_delay_unit.h"
|
||||||
#include "rtc_base/gtest_prod_util.h"
|
#include "rtc_base/gtest_prod_util.h"
|
||||||
|
|
||||||
namespace webrtc {
|
namespace webrtc {
|
||||||
@ -37,18 +38,18 @@ class TransientSuppressorImpl : public TransientSuppressor {
|
|||||||
int detector_rate_hz,
|
int detector_rate_hz,
|
||||||
int num_channels) override;
|
int num_channels) override;
|
||||||
|
|
||||||
void Suppress(float* data,
|
float Suppress(float* data,
|
||||||
size_t data_length,
|
size_t data_length,
|
||||||
int num_channels,
|
int num_channels,
|
||||||
const float* detection_data,
|
const float* detection_data,
|
||||||
size_t detection_length,
|
size_t detection_length,
|
||||||
const float* reference_data,
|
const float* reference_data,
|
||||||
size_t reference_length,
|
size_t reference_length,
|
||||||
float voice_probability,
|
float voice_probability,
|
||||||
bool key_pressed) override;
|
bool key_pressed) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FRIEND_TEST_ALL_PREFIXES(TransientSuppressorImplTest,
|
FRIEND_TEST_ALL_PREFIXES(TransientSuppressorVadModeParametrization,
|
||||||
TypingDetectionLogicWorksAsExpectedForMono);
|
TypingDetectionLogicWorksAsExpectedForMono);
|
||||||
void Suppress(float* in_ptr, float* spectral_mean, float* out_ptr);
|
void Suppress(float* in_ptr, float* spectral_mean, float* out_ptr);
|
||||||
|
|
||||||
@ -61,6 +62,7 @@ class TransientSuppressorImpl : public TransientSuppressor {
|
|||||||
void SoftRestoration(float* spectral_mean);
|
void SoftRestoration(float* spectral_mean);
|
||||||
|
|
||||||
const VadMode vad_mode_;
|
const VadMode vad_mode_;
|
||||||
|
VoiceProbabilityDelayUnit voice_probability_delay_unit_;
|
||||||
|
|
||||||
std::unique_ptr<TransientDetector> detector_;
|
std::unique_ptr<TransientDetector> detector_;
|
||||||
|
|
||||||
|
@ -10,21 +10,37 @@
|
|||||||
|
|
||||||
#include "modules/audio_processing/transient/transient_suppressor.h"
|
#include "modules/audio_processing/transient/transient_suppressor.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/types/optional.h"
|
||||||
#include "modules/audio_processing/transient/common.h"
|
#include "modules/audio_processing/transient/common.h"
|
||||||
#include "modules/audio_processing/transient/transient_suppressor_impl.h"
|
#include "modules/audio_processing/transient/transient_suppressor_impl.h"
|
||||||
#include "test/gtest.h"
|
#include "test/gtest.h"
|
||||||
|
|
||||||
namespace webrtc {
|
namespace webrtc {
|
||||||
|
namespace {
|
||||||
|
constexpr int kMono = 1;
|
||||||
|
|
||||||
class TransientSuppressorImplTest
|
// Returns the index of the first non-zero sample in `samples` or an unspecified
|
||||||
|
// value if no value is zero.
|
||||||
|
absl::optional<int> FindFirstNonZeroSample(const std::vector<float>& samples) {
|
||||||
|
for (size_t i = 0; i < samples.size(); ++i) {
|
||||||
|
if (samples[i] != 0.0f) {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return absl::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
class TransientSuppressorVadModeParametrization
|
||||||
: public ::testing::TestWithParam<TransientSuppressor::VadMode> {};
|
: public ::testing::TestWithParam<TransientSuppressor::VadMode> {};
|
||||||
|
|
||||||
TEST_P(TransientSuppressorImplTest,
|
TEST_P(TransientSuppressorVadModeParametrization,
|
||||||
TypingDetectionLogicWorksAsExpectedForMono) {
|
TypingDetectionLogicWorksAsExpectedForMono) {
|
||||||
static const int kNumChannels = 1;
|
|
||||||
|
|
||||||
TransientSuppressorImpl ts(GetParam(), ts::kSampleRate16kHz,
|
TransientSuppressorImpl ts(GetParam(), ts::kSampleRate16kHz,
|
||||||
ts::kSampleRate16kHz, kNumChannels);
|
ts::kSampleRate16kHz, kMono);
|
||||||
|
|
||||||
// Each key-press enables detection.
|
// Each key-press enables detection.
|
||||||
EXPECT_FALSE(ts.detection_enabled_);
|
EXPECT_FALSE(ts.detection_enabled_);
|
||||||
@ -88,10 +104,72 @@ TEST_P(TransientSuppressorImplTest,
|
|||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
,
|
|
||||||
TransientSuppressorImplTest,
|
TransientSuppressorImplTest,
|
||||||
|
TransientSuppressorVadModeParametrization,
|
||||||
::testing::Values(TransientSuppressor::VadMode::kDefault,
|
::testing::Values(TransientSuppressor::VadMode::kDefault,
|
||||||
TransientSuppressor::VadMode::kRnnVad,
|
TransientSuppressor::VadMode::kRnnVad,
|
||||||
TransientSuppressor::VadMode::kNoVad));
|
TransientSuppressor::VadMode::kNoVad));
|
||||||
|
|
||||||
|
class TransientSuppressorSampleRateParametrization
|
||||||
|
: public ::testing::TestWithParam<int> {};
|
||||||
|
|
||||||
|
// Checks that voice probability and processed audio data are temporally aligned
|
||||||
|
// after `Suppress()` is called.
|
||||||
|
TEST_P(TransientSuppressorSampleRateParametrization,
|
||||||
|
CheckAudioAndVoiceProbabilityTemporallyAligned) {
|
||||||
|
const int sample_rate_hz = GetParam();
|
||||||
|
TransientSuppressorImpl ts(TransientSuppressor::VadMode::kDefault,
|
||||||
|
sample_rate_hz,
|
||||||
|
/*detection_rate_hz=*/sample_rate_hz, kMono);
|
||||||
|
|
||||||
|
const int frame_size = sample_rate_hz * ts::kChunkSizeMs / 1000;
|
||||||
|
std::vector<float> frame(frame_size);
|
||||||
|
|
||||||
|
constexpr int kMaxAttempts = 3;
|
||||||
|
for (int i = 0; i < kMaxAttempts; ++i) {
|
||||||
|
SCOPED_TRACE(i);
|
||||||
|
|
||||||
|
// Call `Suppress()` on frames of non-zero audio samples.
|
||||||
|
std::fill(frame.begin(), frame.end(), 1000.0f);
|
||||||
|
float delayed_voice_probability = ts.Suppress(
|
||||||
|
frame.data(), frame.size(), kMono, /*detection_data=*/nullptr,
|
||||||
|
/*detection_length=*/frame_size, /*reference_data=*/nullptr,
|
||||||
|
/*reference_length=*/frame_size, /*voice_probability=*/1.0f,
|
||||||
|
/*key_pressed=*/false);
|
||||||
|
|
||||||
|
// Detect the algorithmic delay of `TransientSuppressorImpl`.
|
||||||
|
absl::optional<int> frame_delay = FindFirstNonZeroSample(frame);
|
||||||
|
|
||||||
|
// Check that the delayed voice probability is delayed according to the
|
||||||
|
// measured delay.
|
||||||
|
if (frame_delay.has_value()) {
|
||||||
|
if (*frame_delay == 0) {
|
||||||
|
// When the delay is a multiple integer of the frame duration,
|
||||||
|
// `Suppress()` returns a copy of a previously observed voice
|
||||||
|
// probability value.
|
||||||
|
EXPECT_EQ(delayed_voice_probability, 1.0f);
|
||||||
|
} else {
|
||||||
|
// Instead, when the delay is fractional, `Suppress()` returns an
|
||||||
|
// interpolated value. Since the exact value depends on the
|
||||||
|
// interpolation method, we only check that the delayed voice
|
||||||
|
// probability is not zero as it must converge towards the previoulsy
|
||||||
|
// observed value.
|
||||||
|
EXPECT_GT(delayed_voice_probability, 0.0f);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
// The algorithmic delay is longer than the duration of a single frame.
|
||||||
|
// Until the delay is detected, the delayed voice probability is zero.
|
||||||
|
EXPECT_EQ(delayed_voice_probability, 0.0f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(TransientSuppressorImplTest,
|
||||||
|
TransientSuppressorSampleRateParametrization,
|
||||||
|
::testing::Values(ts::kSampleRate8kHz,
|
||||||
|
ts::kSampleRate16kHz,
|
||||||
|
ts::kSampleRate32kHz,
|
||||||
|
ts::kSampleRate48kHz));
|
||||||
|
|
||||||
} // namespace webrtc
|
} // namespace webrtc
|
||||||
|
Reference in New Issue
Block a user