AEC3: Send the spectral power estimates for all channels to AecState

This CL passes the spectral power estimates for all channels into
the AecState.

Bug: webrtc:10913
Change-Id: Ie3b5c443be0c63f205e23ed2bfea06d9c447eb39
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/156165
Reviewed-by: Sam Zackrisson <saza@webrtc.org>
Commit-Queue: Per Åhgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#29417}
This commit is contained in:
Per Åhgren
2019-10-09 13:57:07 +02:00
committed by Commit Bot
parent d9755eea22
commit f9807259a6
9 changed files with 105 additions and 58 deletions

View File

@ -376,15 +376,20 @@ TEST(AdaptiveFirFilter, FilterAndAdapt) {
FftData S; FftData S;
FftData G; FftData G;
FftData E; FftData E;
std::array<float, kFftLengthBy2Plus1> Y2; std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(kNumCaptureChannels);
std::array<float, kFftLengthBy2Plus1> E2_main; std::vector<std::array<float, kFftLengthBy2Plus1>> E2_main(
kNumCaptureChannels);
std::array<float, kFftLengthBy2Plus1> E2_shadow; std::array<float, kFftLengthBy2Plus1> E2_shadow;
// [B,A] = butter(2,100/8000,'high') // [B,A] = butter(2,100/8000,'high')
constexpr CascadedBiQuadFilter::BiQuadCoefficients constexpr CascadedBiQuadFilter::BiQuadCoefficients
kHighPassFilterCoefficients = {{0.97261f, -1.94523f, 0.97261f}, kHighPassFilterCoefficients = {{0.97261f, -1.94523f, 0.97261f},
{-1.94448f, 0.94598f}}; {-1.94448f, 0.94598f}};
Y2.fill(0.f); for (auto& Y2_ch : Y2) {
E2_main.fill(0.f); Y2_ch.fill(0.f);
}
for (auto& E2_main_ch : E2_main) {
E2_main_ch.fill(0.f);
}
E2_shadow.fill(0.f); E2_shadow.fill(0.f);
for (auto& subtractor_output : output) { for (auto& subtractor_output : output) {
subtractor_output.Reset(); subtractor_output.Reset();

View File

@ -164,10 +164,12 @@ void AecState::Update(
adaptive_filter_frequency_response, adaptive_filter_frequency_response,
rtc::ArrayView<const std::vector<float>> adaptive_filter_impulse_response, rtc::ArrayView<const std::vector<float>> adaptive_filter_impulse_response,
const RenderBuffer& render_buffer, const RenderBuffer& render_buffer,
const std::array<float, kFftLengthBy2Plus1>& E2_main, rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2_main,
const std::array<float, kFftLengthBy2Plus1>& Y2, rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
rtc::ArrayView<const SubtractorOutput> subtractor_output) { rtc::ArrayView<const SubtractorOutput> subtractor_output) {
const size_t num_capture_channels = filter_analyzers_.size(); const size_t num_capture_channels = filter_analyzers_.size();
RTC_DCHECK_EQ(num_capture_channels, E2_main.size());
RTC_DCHECK_EQ(num_capture_channels, Y2.size());
RTC_DCHECK_EQ(num_capture_channels, subtractor_output.size()); RTC_DCHECK_EQ(num_capture_channels, subtractor_output.size());
RTC_DCHECK_EQ(num_capture_channels, subtractor_output_analyzers_.size()); RTC_DCHECK_EQ(num_capture_channels, subtractor_output_analyzers_.size());
RTC_DCHECK_EQ(num_capture_channels, RTC_DCHECK_EQ(num_capture_channels,
@ -244,12 +246,12 @@ void AecState::Update(
const auto& X2_input_erle = X2_reverb; const auto& X2_input_erle = X2_reverb;
erle_estimator_.Update(render_buffer, adaptive_filter_frequency_response[0], erle_estimator_.Update(render_buffer, adaptive_filter_frequency_response[0],
X2_input_erle, Y2, E2_main, X2_input_erle, Y2[0], E2_main[0],
subtractor_output_analyzers_[0].ConvergedFilter(), subtractor_output_analyzers_[0].ConvergedFilter(),
config_.erle.onset_detection); config_.erle.onset_detection);
erl_estimator_.Update(subtractor_output_analyzers_[0].ConvergedFilter(), X2, erl_estimator_.Update(subtractor_output_analyzers_[0].ConvergedFilter(), X2,
Y2); Y2[0]);
// Detect and flag echo saturation. // Detect and flag echo saturation.
saturation_detector_.Update(aligned_render_block, SaturatedCapture(), saturation_detector_.Update(aligned_render_block, SaturatedCapture(),

View File

@ -133,8 +133,8 @@ class AecState {
adaptive_filter_frequency_response, adaptive_filter_frequency_response,
rtc::ArrayView<const std::vector<float>> adaptive_filter_impulse_response, rtc::ArrayView<const std::vector<float>> adaptive_filter_impulse_response,
const RenderBuffer& render_buffer, const RenderBuffer& render_buffer,
const std::array<float, kFftLengthBy2Plus1>& E2_main, rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2_main,
const std::array<float, kFftLengthBy2Plus1>& Y2, rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
rtc::ArrayView<const SubtractorOutput> subtractor_output); rtc::ArrayView<const SubtractorOutput> subtractor_output);
// Returns filter length in blocks. // Returns filter length in blocks.

View File

@ -39,8 +39,9 @@ void RunNormalUsageTest(size_t num_render_channels,
DelayEstimate(DelayEstimate::Quality::kRefined, 10); DelayEstimate(DelayEstimate::Quality::kRefined, 10);
std::unique_ptr<RenderDelayBuffer> render_delay_buffer( std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels)); RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels));
std::array<float, kFftLengthBy2Plus1> E2_main = {}; std::vector<std::array<float, kFftLengthBy2Plus1>> E2_main(
std::array<float, kFftLengthBy2Plus1> Y2 = {}; num_capture_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels);
std::vector<std::vector<std::vector<float>>> x( std::vector<std::vector<std::vector<float>>> x(
kNumBands, std::vector<std::vector<float>>( kNumBands, std::vector<std::vector<float>>(
num_render_channels, std::vector<float>(kBlockSize, 0.f))); num_render_channels, std::vector<float>(kBlockSize, 0.f)));
@ -53,6 +54,8 @@ void RunNormalUsageTest(size_t num_render_channels,
subtractor_output[ch].s_main.fill(100.f); subtractor_output[ch].s_main.fill(100.f);
subtractor_output[ch].e_main.fill(100.f); subtractor_output[ch].e_main.fill(100.f);
y[ch].fill(1000.f); y[ch].fill(1000.f);
E2_main[ch].fill(0.f);
Y2[ch].fill(0.f);
} }
Aec3Fft fft; Aec3Fft fft;
std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>>
@ -143,7 +146,9 @@ void RunNormalUsageTest(size_t num_render_channels,
render_delay_buffer->PrepareCaptureProcessing(); render_delay_buffer->PrepareCaptureProcessing();
} }
Y2.fill(10.f * 10000.f * 10000.f); for (auto& Y2_ch : Y2) {
Y2_ch.fill(10.f * 10000.f * 10000.f);
}
for (size_t k = 0; k < 1000; ++k) { for (size_t k = 0; k < 1000; ++k) {
for (size_t ch = 0; ch < num_capture_channels; ++ch) { for (size_t ch = 0; ch < num_capture_channels; ++ch) {
subtractor_output[ch].ComputeMetrics(y[ch]); subtractor_output[ch].ComputeMetrics(y[ch]);
@ -162,8 +167,12 @@ void RunNormalUsageTest(size_t num_render_channels,
EXPECT_EQ(erl[erl.size() - 2], erl[erl.size() - 1]); EXPECT_EQ(erl[erl.size() - 2], erl[erl.size() - 1]);
// Verify that the ERLE is properly estimated // Verify that the ERLE is properly estimated
E2_main.fill(1.f * 10000.f * 10000.f); for (auto& E2_main_ch : E2_main) {
Y2.fill(10.f * E2_main[0]); E2_main_ch.fill(1.f * 10000.f * 10000.f);
}
for (auto& Y2_ch : Y2) {
Y2_ch.fill(10.f * E2_main[0][0]);
}
for (size_t k = 0; k < 1000; ++k) { for (size_t k = 0; k < 1000; ++k) {
for (size_t ch = 0; ch < num_capture_channels; ++ch) { for (size_t ch = 0; ch < num_capture_channels; ++ch) {
subtractor_output[ch].ComputeMetrics(y[ch]); subtractor_output[ch].ComputeMetrics(y[ch]);
@ -187,9 +196,12 @@ void RunNormalUsageTest(size_t num_render_channels,
} }
EXPECT_EQ(erle[erle.size() - 2], erle[erle.size() - 1]); EXPECT_EQ(erle[erle.size() - 2], erle[erle.size() - 1]);
} }
for (auto& E2_main_ch : E2_main) {
E2_main.fill(1.f * 10000.f * 10000.f); E2_main_ch.fill(1.f * 10000.f * 10000.f);
Y2.fill(5.f * E2_main[0]); }
for (auto& Y2_ch : Y2) {
Y2_ch.fill(5.f * E2_main[0][0]);
}
for (size_t k = 0; k < 1000; ++k) { for (size_t k = 0; k < 1000; ++k) {
for (size_t ch = 0; ch < num_capture_channels; ++ch) { for (size_t ch = 0; ch < num_capture_channels; ++ch) {
subtractor_output[ch].ComputeMetrics(y[ch]); subtractor_output[ch].ComputeMetrics(y[ch]);
@ -235,8 +247,9 @@ TEST(AecState, ConvergedFilterDelay) {
std::unique_ptr<RenderDelayBuffer> render_delay_buffer( std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, 48000, 1)); RenderDelayBuffer::Create(config, 48000, 1));
absl::optional<DelayEstimate> delay_estimate; absl::optional<DelayEstimate> delay_estimate;
std::array<float, kFftLengthBy2Plus1> E2_main; std::vector<std::array<float, kFftLengthBy2Plus1>> E2_main(
std::array<float, kFftLengthBy2Plus1> Y2; kNumCaptureChannels);
std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(kNumCaptureChannels);
std::array<float, kBlockSize> x; std::array<float, kBlockSize> x;
EchoPathVariability echo_path_variability( EchoPathVariability echo_path_variability(
false, EchoPathVariability::DelayAdjustment::kNone, false); false, EchoPathVariability::DelayAdjustment::kNone, false);

View File

@ -385,8 +385,8 @@ void EchoRemoverImpl::ProcessCapture(
// Update the AEC state information. // Update the AEC state information.
// TODO(bugs.webrtc.org/10913): Take all subtractors into account. // TODO(bugs.webrtc.org/10913): Take all subtractors into account.
aec_state_.Update(external_delay, subtractor_.FilterFrequencyResponse(), aec_state_.Update(external_delay, subtractor_.FilterFrequencyResponse(),
subtractor_.FilterImpulseResponse(), *render_buffer, E2[0], subtractor_.FilterImpulseResponse(), *render_buffer, E2, Y2,
Y2[0], subtractor_output); subtractor_output);
// Choose the linear output. // Choose the linear output.
const auto& Y_fft = aec_state_.UseLinearFilterOutput() ? E : Y; const auto& Y_fft = aec_state_.UseLinearFilterOutput() ? E : Y;

View File

@ -44,7 +44,8 @@ void RunFilterUpdateTest(int num_blocks_to_process,
FftData* G_last_block) { FftData* G_last_block) {
ApmDataDumper data_dumper(42); ApmDataDumper data_dumper(42);
Aec3Optimization optimization = DetectOptimization(); Aec3Optimization optimization = DetectOptimization();
constexpr size_t kNumChannels = 1; constexpr size_t kNumRenderChannels = 1;
constexpr size_t kNumCaptureChannels = 1;
constexpr int kSampleRateHz = 48000; constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
@ -60,16 +61,16 @@ void RunFilterUpdateTest(int num_blocks_to_process,
config.filter.config_change_duration_blocks, config.filter.config_change_duration_blocks,
1, optimization, &data_dumper); 1, optimization, &data_dumper);
std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> H2( std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> H2(
kNumChannels, std::vector<std::array<float, kFftLengthBy2Plus1>>( kNumCaptureChannels, std::vector<std::array<float, kFftLengthBy2Plus1>>(
main_filter.max_filter_size_partitions(), main_filter.max_filter_size_partitions(),
std::array<float, kFftLengthBy2Plus1>())); std::array<float, kFftLengthBy2Plus1>()));
for (auto& H2_ch : H2) { for (auto& H2_ch : H2) {
for (auto& H2_k : H2_ch) { for (auto& H2_k : H2_ch) {
H2_k.fill(0.f); H2_k.fill(0.f);
} }
} }
std::vector<std::vector<float>> h( std::vector<std::vector<float>> h(
kNumChannels, kNumCaptureChannels,
std::vector<float>( std::vector<float>(
GetTimeDomainLength(main_filter.max_filter_size_partitions()), 0.f)); GetTimeDomainLength(main_filter.max_filter_size_partitions()), 0.f));
@ -83,29 +84,32 @@ void RunFilterUpdateTest(int num_blocks_to_process,
Random random_generator(42U); Random random_generator(42U);
std::vector<std::vector<std::vector<float>>> x( std::vector<std::vector<std::vector<float>>> x(
kNumBands, std::vector<std::vector<float>>( kNumBands, std::vector<std::vector<float>>(
kNumChannels, std::vector<float>(kBlockSize, 0.f))); kNumRenderChannels, std::vector<float>(kBlockSize, 0.f)));
std::vector<float> y(kBlockSize, 0.f); std::vector<float> y(kBlockSize, 0.f);
config.delay.default_delay = 1; config.delay.default_delay = 1;
std::unique_ptr<RenderDelayBuffer> render_delay_buffer( std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels)); RenderDelayBuffer::Create(config, kSampleRateHz, kNumRenderChannels));
AecState aec_state(config, kNumChannels); AecState aec_state(config, kNumCaptureChannels);
RenderSignalAnalyzer render_signal_analyzer(config); RenderSignalAnalyzer render_signal_analyzer(config);
absl::optional<DelayEstimate> delay_estimate; absl::optional<DelayEstimate> delay_estimate;
std::array<float, kFftLength> s_scratch; std::array<float, kFftLength> s_scratch;
std::array<float, kBlockSize> s; std::array<float, kBlockSize> s;
FftData S; FftData S;
FftData G; FftData G;
std::vector<SubtractorOutput> output(kNumChannels); std::vector<SubtractorOutput> output(kNumCaptureChannels);
for (auto& subtractor_output : output) { for (auto& subtractor_output : output) {
subtractor_output.Reset(); subtractor_output.Reset();
} }
FftData& E_main = output[0].E_main; FftData& E_main = output[0].E_main;
FftData E_shadow; FftData E_shadow;
std::array<float, kFftLengthBy2Plus1> Y2; std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(kNumCaptureChannels);
std::array<float, kFftLengthBy2Plus1>& E2_main = output[0].E2_main; std::vector<std::array<float, kFftLengthBy2Plus1>> E2_main(
kNumCaptureChannels);
std::array<float, kBlockSize>& e_main = output[0].e_main; std::array<float, kBlockSize>& e_main = output[0].e_main;
std::array<float, kBlockSize>& e_shadow = output[0].e_shadow; std::array<float, kBlockSize>& e_shadow = output[0].e_shadow;
Y2.fill(0.f); for (auto& Y2_ch : Y2) {
Y2_ch.fill(0.f);
}
constexpr float kScale = 1.0f / kFftLengthBy2; constexpr float kScale = 1.0f / kFftLengthBy2;
@ -197,6 +201,8 @@ void RunFilterUpdateTest(int num_blocks_to_process,
aec_state.HandleEchoPathChange(EchoPathVariability( aec_state.HandleEchoPathChange(EchoPathVariability(
false, EchoPathVariability::DelayAdjustment::kNone, false)); false, EchoPathVariability::DelayAdjustment::kNone, false));
main_filter.ComputeFrequencyResponse(&H2[0]); main_filter.ComputeFrequencyResponse(&H2[0]);
std::copy(output[0].E2_main.begin(), output[0].E2_main.end(),
E2_main[0].begin());
aec_state.Update(delay_estimate, H2, h, aec_state.Update(delay_estimate, H2, h,
*render_delay_buffer->GetRenderBuffer(), E2_main, Y2, *render_delay_buffer->GetRenderBuffer(), E2_main, Y2,
output); output);

View File

@ -33,7 +33,8 @@ TEST(ResidualEchoEstimator, BasicTest) {
RenderDelayBuffer::Create(config, kSampleRateHz, RenderDelayBuffer::Create(config, kSampleRateHz,
num_render_channels)); num_render_channels));
std::array<float, kFftLengthBy2Plus1> E2_main; std::vector<std::array<float, kFftLengthBy2Plus1>> E2_main(
num_capture_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> S2_linear( std::vector<std::array<float, kFftLengthBy2Plus1>> S2_linear(
num_capture_channels); num_capture_channels);
std::vector<std::array<float, kFftLengthBy2Plus1>> Y2( std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(
@ -72,9 +73,13 @@ TEST(ResidualEchoEstimator, BasicTest) {
y.fill(0.f); y.fill(0.f);
constexpr float kLevel = 10.f; constexpr float kLevel = 10.f;
E2_main.fill(kLevel); for (auto& E2_main_ch : E2_main) {
E2_main_ch.fill(kLevel);
}
S2_linear[0].fill(kLevel); S2_linear[0].fill(kLevel);
Y2[0].fill(kLevel); for (auto& Y2_ch : Y2) {
Y2_ch.fill(kLevel);
}
for (int k = 0; k < 1993; ++k) { for (int k = 0; k < 1993; ++k) {
RandomizeSampleVector(&random_generator, x[0][0]); RandomizeSampleVector(&random_generator, x[0][0]);
@ -85,8 +90,8 @@ TEST(ResidualEchoEstimator, BasicTest) {
render_delay_buffer->PrepareCaptureProcessing(); render_delay_buffer->PrepareCaptureProcessing();
aec_state.Update(delay_estimate, H2, h, aec_state.Update(delay_estimate, H2, h,
*render_delay_buffer->GetRenderBuffer(), E2_main, *render_delay_buffer->GetRenderBuffer(), E2_main, Y2,
Y2[0], output); output);
estimator.Estimate(aec_state, *render_delay_buffer->GetRenderBuffer(), estimator.Estimate(aec_state, *render_delay_buffer->GetRenderBuffer(),
S2_linear, Y2, R2); S2_linear, Y2, R2);

View File

@ -58,13 +58,18 @@ std::vector<float> RunSubtractorTest(
RenderSignalAnalyzer render_signal_analyzer(config); RenderSignalAnalyzer render_signal_analyzer(config);
Random random_generator(42U); Random random_generator(42U);
Aec3Fft fft; Aec3Fft fft;
std::array<float, kFftLengthBy2Plus1> Y2; std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels);
std::array<float, kFftLengthBy2Plus1> E2_main; std::vector<std::array<float, kFftLengthBy2Plus1>> E2_main(
num_capture_channels);
std::array<float, kFftLengthBy2Plus1> E2_shadow; std::array<float, kFftLengthBy2Plus1> E2_shadow;
AecState aec_state(config, num_capture_channels); AecState aec_state(config, num_capture_channels);
x_old.fill(0.f); x_old.fill(0.f);
Y2.fill(0.f); for (auto& Y2_ch : Y2) {
E2_main.fill(0.f); Y2_ch.fill(0.f);
}
for (auto& E2_main_ch : E2_main) {
E2_main_ch.fill(0.f);
}
E2_shadow.fill(0.f); E2_shadow.fill(0.f);
std::vector<std::vector<std::unique_ptr<DelayBuffer<float>>>> delay_buffer( std::vector<std::vector<std::unique_ptr<DelayBuffer<float>>>> delay_buffer(

View File

@ -58,35 +58,40 @@ TEST(SuppressionGain, NullOutputGains) {
// Does a sanity check that the gains are correctly computed. // Does a sanity check that the gains are correctly computed.
TEST(SuppressionGain, BasicGainComputation) { TEST(SuppressionGain, BasicGainComputation) {
constexpr size_t kNumChannels = 1; constexpr size_t kNumRenderChannels = 1;
constexpr size_t kNumCaptureChannels = 1;
constexpr int kSampleRateHz = 16000; constexpr int kSampleRateHz = 16000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
SuppressionGain suppression_gain(EchoCanceller3Config(), DetectOptimization(), SuppressionGain suppression_gain(EchoCanceller3Config(), DetectOptimization(),
kSampleRateHz); kSampleRateHz);
RenderSignalAnalyzer analyzer(EchoCanceller3Config{}); RenderSignalAnalyzer analyzer(EchoCanceller3Config{});
float high_bands_gain; float high_bands_gain;
std::array<float, kFftLengthBy2Plus1> E2; std::vector<std::array<float, kFftLengthBy2Plus1>> E2(kNumCaptureChannels);
std::array<float, kFftLengthBy2Plus1> S2; std::array<float, kFftLengthBy2Plus1> S2;
std::array<float, kFftLengthBy2Plus1> Y2; std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(kNumCaptureChannels);
std::array<float, kFftLengthBy2Plus1> R2; std::array<float, kFftLengthBy2Plus1> R2;
std::array<float, kFftLengthBy2Plus1> N2; std::array<float, kFftLengthBy2Plus1> N2;
std::array<float, kFftLengthBy2Plus1> g; std::array<float, kFftLengthBy2Plus1> g;
std::vector<SubtractorOutput> output(kNumChannels); std::vector<SubtractorOutput> output(kNumCaptureChannels);
std::array<float, kBlockSize> y; std::array<float, kBlockSize> y;
std::vector<std::vector<std::vector<float>>> x( std::vector<std::vector<std::vector<float>>> x(
kNumBands, std::vector<std::vector<float>>( kNumBands, std::vector<std::vector<float>>(
kNumChannels, std::vector<float>(kBlockSize, 0.f))); kNumRenderChannels, std::vector<float>(kBlockSize, 0.f)));
EchoCanceller3Config config; EchoCanceller3Config config;
AecState aec_state(config, kNumChannels); AecState aec_state(config, kNumCaptureChannels);
ApmDataDumper data_dumper(42); ApmDataDumper data_dumper(42);
Subtractor subtractor(config, 1, 1, &data_dumper, DetectOptimization()); Subtractor subtractor(config, 1, 1, &data_dumper, DetectOptimization());
std::unique_ptr<RenderDelayBuffer> render_delay_buffer( std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels)); RenderDelayBuffer::Create(config, kSampleRateHz, kNumRenderChannels));
absl::optional<DelayEstimate> delay_estimate; absl::optional<DelayEstimate> delay_estimate;
// Ensure that a strong noise is detected to mask any echoes. // Ensure that a strong noise is detected to mask any echoes.
E2.fill(10.f); for (auto& E2_k : E2) {
Y2.fill(10.f); E2_k.fill(10.f);
}
for (auto& Y2_k : Y2) {
Y2_k.fill(10.f);
}
R2.fill(0.1f); R2.fill(0.1f);
S2.fill(0.1f); S2.fill(0.1f);
N2.fill(100.f); N2.fill(100.f);
@ -106,15 +111,19 @@ TEST(SuppressionGain, BasicGainComputation) {
aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(), aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(),
subtractor.FilterImpulseResponse(), subtractor.FilterImpulseResponse(),
*render_delay_buffer->GetRenderBuffer(), E2, Y2, output); *render_delay_buffer->GetRenderBuffer(), E2, Y2, output);
suppression_gain.GetGain(E2, S2, R2, N2, analyzer, aec_state, x, suppression_gain.GetGain(E2[0], S2, R2, N2, analyzer, aec_state, x,
&high_bands_gain, &g); &high_bands_gain, &g);
} }
std::for_each(g.begin(), g.end(), std::for_each(g.begin(), g.end(),
[](float a) { EXPECT_NEAR(1.f, a, 0.001); }); [](float a) { EXPECT_NEAR(1.f, a, 0.001); });
// Ensure that a strong nearend is detected to mask any echoes. // Ensure that a strong nearend is detected to mask any echoes.
E2.fill(100.f); for (auto& E2_k : E2) {
Y2.fill(100.f); E2_k.fill(100.f);
}
for (auto& Y2_k : Y2) {
Y2_k.fill(100.f);
}
R2.fill(0.1f); R2.fill(0.1f);
S2.fill(0.1f); S2.fill(0.1f);
N2.fill(0.f); N2.fill(0.f);
@ -123,18 +132,20 @@ TEST(SuppressionGain, BasicGainComputation) {
aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(), aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(),
subtractor.FilterImpulseResponse(), subtractor.FilterImpulseResponse(),
*render_delay_buffer->GetRenderBuffer(), E2, Y2, output); *render_delay_buffer->GetRenderBuffer(), E2, Y2, output);
suppression_gain.GetGain(E2, S2, R2, N2, analyzer, aec_state, x, suppression_gain.GetGain(E2[0], S2, R2, N2, analyzer, aec_state, x,
&high_bands_gain, &g); &high_bands_gain, &g);
} }
std::for_each(g.begin(), g.end(), std::for_each(g.begin(), g.end(),
[](float a) { EXPECT_NEAR(1.f, a, 0.001); }); [](float a) { EXPECT_NEAR(1.f, a, 0.001); });
// Ensure that a strong echo is suppressed. // Ensure that a strong echo is suppressed.
E2.fill(1000000000.f); for (auto& E2_k : E2) {
E2_k.fill(1000000000.f);
}
R2.fill(10000000000000.f); R2.fill(10000000000000.f);
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
suppression_gain.GetGain(E2, S2, R2, N2, analyzer, aec_state, x, suppression_gain.GetGain(E2[0], S2, R2, N2, analyzer, aec_state, x,
&high_bands_gain, &g); &high_bands_gain, &g);
} }
std::for_each(g.begin(), g.end(), std::for_each(g.begin(), g.end(),