AEC3: Suppression filter handles multiple channels

The suppression filter is extended to support the synthesis
of multiple channels. This CL also includes a major clean-up of ApplyGain.

The CL has been tested for bit-exactness for single-channel output.

Bug: webrtc:10913
Change-Id: I1319f127981552e17dec66701a248d34dcf0e563
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/154341
Commit-Queue: Gustaf Ullberg <gustaf@webrtc.org>
Reviewed-by: Per Åhgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#29284}
This commit is contained in:
Gustaf Ullberg
2019-09-24 15:05:04 +02:00
committed by Commit Bot
parent 67309ef93c
commit af3fdc069d
4 changed files with 140 additions and 126 deletions

View File

@ -191,7 +191,9 @@ EchoRemoverImpl::EchoRemoverImpl(const EchoCanceller3Config& config,
subtractors_(num_capture_channels_), subtractors_(num_capture_channels_),
suppression_gains_(num_capture_channels_), suppression_gains_(num_capture_channels_),
cngs_(num_capture_channels_), cngs_(num_capture_channels_),
suppression_filter_(optimization_, sample_rate_hz_), suppression_filter_(optimization_,
sample_rate_hz_,
num_capture_channels_),
render_signal_analyzer_(config_), render_signal_analyzer_(config_),
residual_echo_estimators_(num_capture_channels_), residual_echo_estimators_(num_capture_channels_),
aec_state_(config_), aec_state_(config_),
@ -378,7 +380,7 @@ void EchoRemoverImpl::ProcessCapture(
E2[0], Y2[0], subtractor_output[0], y0); E2[0], Y2[0], subtractor_output[0], y0);
// Choose the linear output. // Choose the linear output.
const auto& Y_fft = aec_state_.UseLinearFilterOutput() ? E[0] : Y[0]; const auto& Y_fft = aec_state_.UseLinearFilterOutput() ? E : Y;
#if WEBRTC_APM_DEBUG_DUMP #if WEBRTC_APM_DEBUG_DUMP
if (aec_state_.UseLinearFilterOutput()) { if (aec_state_.UseLinearFilterOutput()) {
@ -439,8 +441,7 @@ void EchoRemoverImpl::ProcessCapture(
[](float a, float b) { return std::min(a, b); }); [](float a, float b) { return std::min(a, b); });
} }
// TODO(bugs.webrtc.org/10913): Make ApplyGain handle multiple channels. suppression_filter_.ApplyGain(comfort_noise, high_band_comfort_noise, G,
suppression_filter_.ApplyGain(comfort_noise[0], high_band_comfort_noise[0], G,
high_bands_gain, Y_fft, y); high_bands_gain, Y_fft, y);
// Update the metrics. // Update the metrics.

View File

@ -61,107 +61,117 @@ const float kSqrtHanning[kFftLength] = {
} // namespace } // namespace
SuppressionFilter::SuppressionFilter(Aec3Optimization optimization, SuppressionFilter::SuppressionFilter(Aec3Optimization optimization,
int sample_rate_hz) int sample_rate_hz,
size_t num_capture_channels)
: optimization_(optimization), : optimization_(optimization),
sample_rate_hz_(sample_rate_hz), sample_rate_hz_(sample_rate_hz),
num_capture_channels_(num_capture_channels),
fft_(), fft_(),
e_output_old_(NumBandsForRate(sample_rate_hz_)) { e_output_old_(NumBandsForRate(sample_rate_hz_),
std::vector<std::array<float, kFftLengthBy2>>(
num_capture_channels_)) {
RTC_DCHECK(ValidFullBandRate(sample_rate_hz_)); RTC_DCHECK(ValidFullBandRate(sample_rate_hz_));
std::for_each(e_output_old_.begin(), e_output_old_.end(), for (size_t b = 0; b < e_output_old_.size(); ++b) {
[](std::array<float, kFftLengthBy2>& a) { a.fill(0.f); }); for (size_t ch = 0; ch < e_output_old_[b].size(); ++ch) {
e_output_old_[b][ch].fill(0.f);
}
}
} }
SuppressionFilter::~SuppressionFilter() = default; SuppressionFilter::~SuppressionFilter() = default;
void SuppressionFilter::ApplyGain( void SuppressionFilter::ApplyGain(
const FftData& comfort_noise, rtc::ArrayView<const FftData> comfort_noise,
const FftData& comfort_noise_high_band, rtc::ArrayView<const FftData> comfort_noise_high_band,
const std::array<float, kFftLengthBy2Plus1>& suppression_gain, const std::array<float, kFftLengthBy2Plus1>& suppression_gain,
float high_bands_gain, float high_bands_gain,
const FftData& E_lowest_band, rtc::ArrayView<const FftData> E_lowest_band,
std::vector<std::vector<std::vector<float>>>* e) { std::vector<std::vector<std::vector<float>>>* e) {
RTC_DCHECK(e); RTC_DCHECK(e);
RTC_DCHECK_EQ(e->size(), NumBandsForRate(sample_rate_hz_)); RTC_DCHECK_EQ(e->size(), NumBandsForRate(sample_rate_hz_));
FftData E;
// Analysis filterbank.
E.Assign(E_lowest_band);
// Apply gain.
std::transform(suppression_gain.begin(), suppression_gain.end(), E.re.begin(),
E.re.begin(), std::multiplies<float>());
std::transform(suppression_gain.begin(), suppression_gain.end(), E.im.begin(),
E.im.begin(), std::multiplies<float>());
// Comfort noise gain is sqrt(1-g^2), where g is the suppression gain. // Comfort noise gain is sqrt(1-g^2), where g is the suppression gain.
std::array<float, kFftLengthBy2Plus1> noise_gain; std::array<float, kFftLengthBy2Plus1> noise_gain;
std::transform(suppression_gain.begin(), suppression_gain.end(), for (size_t i = 0; i < kFftLengthBy2Plus1; ++i) {
noise_gain.begin(), [](float g) { return 1.f - g * g; }); noise_gain[i] = 1.f - suppression_gain[i] * suppression_gain[i];
}
aec3::VectorMath(optimization_).Sqrt(noise_gain); aec3::VectorMath(optimization_).Sqrt(noise_gain);
// Scale and add the comfort noise. const float high_bands_noise_scaling =
for (size_t k = 0; k < kFftLengthBy2Plus1; k++) { 0.4f * std::sqrt(1.f - high_bands_gain * high_bands_gain);
E.re[k] += noise_gain[k] * comfort_noise.re[k];
E.im[k] += noise_gain[k] * comfort_noise.im[k];
}
// Synthesis filterbank. for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
std::array<float, kFftLength> e_extended; FftData E;
constexpr float kIfftNormalization = 2.f / kFftLength;
fft_.Ifft(E, &e_extended); // Analysis filterbank.
std::transform(e_output_old_[0].begin(), e_output_old_[0].end(), E.Assign(E_lowest_band[ch]);
std::begin(kSqrtHanning) + kFftLengthBy2, (*e)[0][0].begin(),
[&](float a, float b) { return kIfftNormalization * a * b; });
std::transform(e_extended.begin(), e_extended.begin() + kFftLengthBy2,
std::begin(kSqrtHanning), e_extended.begin(),
[&](float a, float b) { return kIfftNormalization * a * b; });
std::transform((*e)[0][0].begin(), (*e)[0][0].end(), e_extended.begin(),
(*e)[0][0].begin(), std::plus<float>());
std::for_each((*e)[0][0].begin(), (*e)[0][0].end(), [](float& x_k) {
x_k = rtc::SafeClamp(x_k, -32768.f, 32767.f);
});
std::copy(e_extended.begin() + kFftLengthBy2, e_extended.begin() + kFftLength,
std::begin(e_output_old_[0]));
if (e->size() > 1) { for (size_t i = 0; i < kFftLengthBy2Plus1; ++i) {
// Form time-domain high-band noise. // Apply suppression gains.
std::array<float, kFftLength> time_domain_high_band_noise; E.re[i] *= suppression_gain[i];
std::transform(comfort_noise_high_band.re.begin(), E.im[i] *= suppression_gain[i];
comfort_noise_high_band.re.end(), E.re.begin(),
[&](float a) { return kIfftNormalization * a; });
std::transform(comfort_noise_high_band.im.begin(),
comfort_noise_high_band.im.end(), E.im.begin(),
[&](float a) { return kIfftNormalization * a; });
fft_.Ifft(E, &time_domain_high_band_noise);
// Scale and apply the noise to the signals. // Scale and add the comfort noise.
const float high_bands_noise_scaling = E.re[i] += noise_gain[i] * comfort_noise[ch].re[i];
0.4f * std::sqrt(1.f - high_bands_gain * high_bands_gain); E.im[i] += noise_gain[i] * comfort_noise[ch].im[i];
std::transform(
(*e)[1][0].begin(), (*e)[1][0].end(),
time_domain_high_band_noise.begin(), (*e)[1][0].begin(),
[&](float a, float b) {
return std::max(
std::min(b * high_bands_noise_scaling + high_bands_gain * a,
32767.0f),
-32768.0f);
});
if (e->size() > 2) {
RTC_DCHECK_EQ(3, e->size());
std::for_each((*e)[2][0].begin(), (*e)[2][0].end(), [&](float& a) {
a = rtc::SafeClamp(a * high_bands_gain, -32768.f, 32767.f);
});
} }
std::array<float, kFftLengthBy2> tmp; // Synthesis filterbank.
for (size_t k = 1; k < e->size(); ++k) { std::array<float, kFftLength> e_extended;
std::copy((*e)[k][0].begin(), (*e)[k][0].end(), tmp.begin()); constexpr float kIfftNormalization = 2.f / kFftLength;
std::copy(e_output_old_[k].begin(), e_output_old_[k].end(), fft_.Ifft(E, &e_extended);
(*e)[k][0].begin());
std::copy(tmp.begin(), tmp.end(), e_output_old_[k].begin()); auto& e0 = (*e)[0][ch];
auto& e0_old = e_output_old_[0][ch];
// Window and add the first half of e_extended with the second half of
// e_extended from the previous block.
for (size_t i = 0; i < kFftLengthBy2; ++i) {
e0[i] = e0_old[i] * kSqrtHanning[kFftLengthBy2 + i];
e0[i] += e_extended[i] * kSqrtHanning[i];
e0[i] *= kIfftNormalization;
}
// The second half of e_extended is stored for the succeeding frame.
std::copy(e_extended.begin() + kFftLengthBy2,
e_extended.begin() + kFftLength, std::begin(e0_old));
// Apply suppression gain to upper bands.
for (size_t b = 1; b < e->size(); ++b) {
auto& e_band = (*e)[b][ch];
for (size_t i = 0; i < kFftLengthBy2; ++i) {
e_band[i] *= high_bands_gain;
}
}
// Add comfort noise to band 1.
if (e->size() > 1) {
E.Assign(comfort_noise_high_band[ch]);
std::array<float, kFftLength> time_domain_high_band_noise;
fft_.Ifft(E, &time_domain_high_band_noise);
auto& e1 = (*e)[1][ch];
const float gain = high_bands_noise_scaling * kIfftNormalization;
for (size_t i = 0; i < kFftLengthBy2; ++i) {
e1[i] += time_domain_high_band_noise[i] * gain;
}
}
// Delay upper bands to match the delay of the filter bank.
for (size_t b = 1; b < e->size(); ++b) {
auto& e_band = (*e)[b][ch];
auto& e_band_old = e_output_old_[b][ch];
for (size_t i = 0; i < kFftLengthBy2; ++i) {
std::swap(e_band[i], e_band_old[i]);
}
}
// Clamp output of all bands.
for (size_t b = 0; b < e->size(); ++b) {
auto& e_band = (*e)[b][ch];
for (size_t i = 0; i < kFftLengthBy2; ++i) {
e_band[i] = rtc::SafeClamp(e_band[i], -32768.f, 32767.f);
}
} }
} }
} }

View File

@ -24,21 +24,24 @@ namespace webrtc {
class SuppressionFilter { class SuppressionFilter {
public: public:
SuppressionFilter(Aec3Optimization optimization, int sample_rate_hz); SuppressionFilter(Aec3Optimization optimization,
int sample_rate_hz,
size_t num_capture_channels_);
~SuppressionFilter(); ~SuppressionFilter();
void ApplyGain(const FftData& comfort_noise, void ApplyGain(rtc::ArrayView<const FftData> comfort_noise,
const FftData& comfort_noise_high_bands, rtc::ArrayView<const FftData> comfort_noise_high_bands,
const std::array<float, kFftLengthBy2Plus1>& suppression_gain, const std::array<float, kFftLengthBy2Plus1>& suppression_gain,
float high_bands_gain, float high_bands_gain,
const FftData& E_lowest_band, rtc::ArrayView<const FftData> E_lowest_band,
std::vector<std::vector<std::vector<float>>>* e); std::vector<std::vector<std::vector<float>>>* e);
private: private:
const Aec3Optimization optimization_; const Aec3Optimization optimization_;
const int sample_rate_hz_; const int sample_rate_hz_;
const size_t num_capture_channels_;
const OouraFft ooura_fft_; const OouraFft ooura_fft_;
const Aec3Fft fft_; const Aec3Fft fft_;
std::vector<std::array<float, kFftLengthBy2>> e_output_old_; std::vector<std::vector<std::array<float, kFftLengthBy2>>> e_output_old_;
RTC_DISALLOW_COPY_AND_ASSIGN(SuppressionFilter); RTC_DISALLOW_COPY_AND_ASSIGN(SuppressionFilter);
}; };

View File

@ -51,46 +51,46 @@ void ProduceSinusoid(int sample_rate_hz,
// Verifies the check for null suppressor output. // Verifies the check for null suppressor output.
TEST(SuppressionFilter, NullOutput) { TEST(SuppressionFilter, NullOutput) {
FftData cn; std::vector<FftData> cn(1);
FftData cn_high_bands; std::vector<FftData> cn_high_bands(1);
FftData E; std::vector<FftData> E(1);
std::array<float, kFftLengthBy2Plus1> gain; std::array<float, kFftLengthBy2Plus1> gain;
EXPECT_DEATH(SuppressionFilter(Aec3Optimization::kNone, 16000) EXPECT_DEATH(SuppressionFilter(Aec3Optimization::kNone, 16000, 1)
.ApplyGain(cn, cn_high_bands, gain, 1.0f, E, nullptr), .ApplyGain(cn, cn_high_bands, gain, 1.0f, E, nullptr),
""); "");
} }
// Verifies the check for allowed sample rate. // Verifies the check for allowed sample rate.
TEST(SuppressionFilter, ProperSampleRate) { TEST(SuppressionFilter, ProperSampleRate) {
EXPECT_DEATH(SuppressionFilter(Aec3Optimization::kNone, 16001), ""); EXPECT_DEATH(SuppressionFilter(Aec3Optimization::kNone, 16001, 1), "");
} }
#endif #endif
// Verifies that no comfort noise is added when the gain is 1. // Verifies that no comfort noise is added when the gain is 1.
TEST(SuppressionFilter, ComfortNoiseInUnityGain) { TEST(SuppressionFilter, ComfortNoiseInUnityGain) {
SuppressionFilter filter(Aec3Optimization::kNone, 48000); SuppressionFilter filter(Aec3Optimization::kNone, 48000, 1);
FftData cn; std::vector<FftData> cn(1);
FftData cn_high_bands; std::vector<FftData> cn_high_bands(1);
std::array<float, kFftLengthBy2Plus1> gain; std::array<float, kFftLengthBy2Plus1> gain;
std::array<float, kFftLengthBy2> e_old_; std::array<float, kFftLengthBy2> e_old_;
Aec3Fft fft; Aec3Fft fft;
e_old_.fill(0.f); e_old_.fill(0.f);
gain.fill(1.f); gain.fill(1.f);
cn.re.fill(1.f); cn[0].re.fill(1.f);
cn.im.fill(1.f); cn[0].im.fill(1.f);
cn_high_bands.re.fill(1.f); cn_high_bands[0].re.fill(1.f);
cn_high_bands.im.fill(1.f); cn_high_bands[0].im.fill(1.f);
std::vector<std::vector<std::vector<float>>> e( std::vector<std::vector<std::vector<float>>> e(
3, 3,
std::vector<std::vector<float>>(1, std::vector<float>(kBlockSize, 0.f))); std::vector<std::vector<float>>(1, std::vector<float>(kBlockSize, 0.f)));
std::vector<std::vector<std::vector<float>>> e_ref = e; std::vector<std::vector<std::vector<float>>> e_ref = e;
FftData E; std::vector<FftData> E(1);
fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E); fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E[0]);
std::copy(e[0][0].begin(), e[0][0].end(), e_old_.begin()); std::copy(e[0][0].begin(), e[0][0].end(), e_old_.begin());
filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e); filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e);
@ -110,9 +110,9 @@ TEST(SuppressionFilter, SignalSuppression) {
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
constexpr size_t kNumChannels = 1; constexpr size_t kNumChannels = 1;
SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz); SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz, 1);
FftData cn; std::vector<FftData> cn(1);
FftData cn_high_bands; std::vector<FftData> cn_high_bands(1);
std::array<float, kFftLengthBy2> e_old_; std::array<float, kFftLengthBy2> e_old_;
Aec3Fft fft; Aec3Fft fft;
std::array<float, kFftLengthBy2Plus1> gain; std::array<float, kFftLengthBy2Plus1> gain;
@ -124,10 +124,10 @@ TEST(SuppressionFilter, SignalSuppression) {
gain.fill(1.f); gain.fill(1.f);
std::for_each(gain.begin() + 10, gain.end(), [](float& a) { a = 0.f; }); std::for_each(gain.begin() + 10, gain.end(), [](float& a) { a = 0.f; });
cn.re.fill(0.f); cn[0].re.fill(0.f);
cn.im.fill(0.f); cn[0].im.fill(0.f);
cn_high_bands.re.fill(0.f); cn_high_bands[0].re.fill(0.f);
cn_high_bands.im.fill(0.f); cn_high_bands[0].im.fill(0.f);
size_t sample_counter = 0; size_t sample_counter = 0;
@ -138,8 +138,8 @@ TEST(SuppressionFilter, SignalSuppression) {
e0_input = std::inner_product(e[0][0].begin(), e[0][0].end(), e0_input = std::inner_product(e[0][0].begin(), e[0][0].end(),
e[0][0].begin(), e0_input); e[0][0].begin(), e0_input);
FftData E; std::vector<FftData> E(1);
fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E); fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E[0]);
std::copy(e[0][0].begin(), e[0][0].end(), e_old_.begin()); std::copy(e[0][0].begin(), e[0][0].end(), e_old_.begin());
filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e); filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e);
@ -157,11 +157,11 @@ TEST(SuppressionFilter, SignalTransparency) {
constexpr int kSampleRateHz = 48000; constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz); SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz, 1);
FftData cn; std::vector<FftData> cn(1);
std::array<float, kFftLengthBy2> e_old_; std::array<float, kFftLengthBy2> e_old_;
Aec3Fft fft; Aec3Fft fft;
FftData cn_high_bands; std::vector<FftData> cn_high_bands(1);
std::array<float, kFftLengthBy2Plus1> gain; std::array<float, kFftLengthBy2Plus1> gain;
std::vector<std::vector<std::vector<float>>> e( std::vector<std::vector<std::vector<float>>> e(
kNumBands, std::vector<std::vector<float>>( kNumBands, std::vector<std::vector<float>>(
@ -170,10 +170,10 @@ TEST(SuppressionFilter, SignalTransparency) {
gain.fill(1.f); gain.fill(1.f);
std::for_each(gain.begin() + 30, gain.end(), [](float& a) { a = 0.f; }); std::for_each(gain.begin() + 30, gain.end(), [](float& a) { a = 0.f; });
cn.re.fill(0.f); cn[0].re.fill(0.f);
cn.im.fill(0.f); cn[0].im.fill(0.f);
cn_high_bands.re.fill(0.f); cn_high_bands[0].re.fill(0.f);
cn_high_bands.im.fill(0.f); cn_high_bands[0].im.fill(0.f);
size_t sample_counter = 0; size_t sample_counter = 0;
@ -184,8 +184,8 @@ TEST(SuppressionFilter, SignalTransparency) {
e0_input = std::inner_product(e[0][0].begin(), e[0][0].end(), e0_input = std::inner_product(e[0][0].begin(), e[0][0].end(),
e[0][0].begin(), e0_input); e[0][0].begin(), e0_input);
FftData E; std::vector<FftData> E(1);
fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E); fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E[0]);
std::copy(e[0][0].begin(), e[0][0].end(), e_old_.begin()); std::copy(e[0][0].begin(), e[0][0].end(), e_old_.begin());
filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e); filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e);
@ -202,9 +202,9 @@ TEST(SuppressionFilter, Delay) {
constexpr int kSampleRateHz = 48000; constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz); SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz, 1);
FftData cn; std::vector<FftData> cn(1);
FftData cn_high_bands; std::vector<FftData> cn_high_bands(1);
std::array<float, kFftLengthBy2> e_old_; std::array<float, kFftLengthBy2> e_old_;
Aec3Fft fft; Aec3Fft fft;
std::array<float, kFftLengthBy2Plus1> gain; std::array<float, kFftLengthBy2Plus1> gain;
@ -214,10 +214,10 @@ TEST(SuppressionFilter, Delay) {
gain.fill(1.f); gain.fill(1.f);
cn.re.fill(0.f); cn[0].re.fill(0.f);
cn.im.fill(0.f); cn[0].im.fill(0.f);
cn_high_bands.re.fill(0.f); cn_high_bands[0].re.fill(0.f);
cn_high_bands.im.fill(0.f); cn_high_bands[0].im.fill(0.f);
for (size_t k = 0; k < 100; ++k) { for (size_t k = 0; k < 100; ++k) {
for (size_t band = 0; band < kNumBands; ++band) { for (size_t band = 0; band < kNumBands; ++band) {
@ -228,8 +228,8 @@ TEST(SuppressionFilter, Delay) {
} }
} }
FftData E; std::vector<FftData> E(1);
fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E); fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E[0]);
std::copy(e[0][0].begin(), e[0][0].end(), e_old_.begin()); std::copy(e[0][0].begin(), e[0][0].end(), e_old_.begin());
filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e); filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e);